Compare commits

..

1 Commits

Author SHA1 Message Date
3f3730b657 Preliminary implementation for the vocos model. 2024-02-14 22:16:09 +01:00
94 changed files with 1179 additions and 10450 deletions

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.4.1"
version = "0.4.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -31,18 +31,17 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" }
candle-datasets = { path = "./candle-datasets", version = "0.4.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" }
candle-kernels = { path = "./candle-kernels", version = "0.4.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" }
candle-nn = { path = "./candle-nn", version = "0.4.1" }
candle-onnx = { path = "./candle-onnx", version = "0.4.1" }
candle-transformers = { path = "./candle-transformers", version = "0.4.1" }
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
candle-nn = { path = "./candle-nn", version = "0.4.0" }
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }

View File

@ -63,8 +63,6 @@ We also provide a some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports
@ -76,10 +74,9 @@ We also provide a some command line based examples using state of the art models
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
@ -110,12 +107,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@ -195,10 +187,9 @@ If you have an addition to this list, please submit a pull request.
- Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- Falcon.
- StarCoder, StarCoder2.
- StarCoder.
- Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba
- Gemma 2b and 7b.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@ -206,7 +197,7 @@ If you have an addition to this list, please submit a pull request.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5.
- RWKV v5 and v6.
- RWKV.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
@ -216,22 +207,18 @@ If you have an addition to this list, please submit a pull request.
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- TrOCR.
- Audio.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA).
ConvNeXTv2.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.

View File

@ -5,32 +5,25 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Module, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
println!("{out_t}");
let in_t = in_t.to_device(&Device::Cpu)?;
let k_t = k_t.to_device(&Device::Cpu)?;
let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
.sqr()?
.sum_all()?;
println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(())
}

View File

@ -113,7 +113,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. }
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
@ -250,7 +250,6 @@ impl Tensor {
out_padding,
*stride,
*dilation,
/* groups */ 1,
)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
@ -348,18 +347,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
}
Op::UpsampleNearest1D { arg, target_size } => {
let (_n, c, size) = arg.dims3()?;
if target_size % size != 0 {
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale = target_size / size;
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D {
arg,
target_h,

View File

@ -187,16 +187,36 @@ impl Tensor {
}
}
fn conv_transpose1d_single_group(
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
params: &ParamsConvTranspose1D,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d(
self.layout(),
&kernel.storage(),
kernel.layout(),
params,
&params,
)?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg,
@ -210,49 +230,6 @@ impl Tensor {
Ok(crate::tensor::from_storage(storage, out_dims, op, false))
}
/// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self,
kernel: &Self,
padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
let (c_in_k, c_out, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
if c_in % groups != 0 {
crate::bail!("in_channel {c_in} is not divisible by the number of groups")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in: c_in / groups,
padding,
output_padding,
stride,
dilation,
};
if groups == 1 {
self.conv_transpose1d_single_group(kernel, &params)
} else {
let blocks = self.chunk(groups, 1)?;
let kernel = kernel.chunk(groups, 0)?;
let blocks = blocks
.iter()
.zip(&kernel)
.map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
.collect::<Result<Vec<_>>>()?;
Tensor::cat(&blocks, 1)
}
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage =
self.storage()

View File

@ -1263,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0;
let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out();
@ -2575,7 +2574,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
}
}
@ -2584,7 +2583,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
}
}
@ -2601,7 +2600,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
}
}

View File

@ -129,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
}
}
impl<M: Module> Module for Option<&M> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
None => Ok(xs.clone()),
Some(m) => m.forward(xs),
}
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {

View File

@ -738,7 +738,6 @@ impl BackendStorage for MetalStorage {
("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF,
@ -755,7 +754,6 @@ impl BackendStorage for MetalStorage {
("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
(name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
}
@ -829,9 +827,9 @@ impl BackendStorage for MetalStorage {
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&buffer,
)
.map_err(MetalError::from)?;
@ -1266,7 +1264,7 @@ impl BackendStorage for MetalStorage {
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let src_shape = src_l.shape();
@ -1638,7 +1636,7 @@ impl BackendDevice for MetalDevice {
min as f32,
max as f32,
shape.elem_count(),
&self.seed.lock().unwrap(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1669,7 +1667,7 @@ impl BackendDevice for MetalDevice {
mean as f32,
stddev as f32,
shape.elem_count(),
&self.seed.lock().unwrap(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;

View File

@ -132,10 +132,7 @@ pub enum Op {
stride: (usize, usize),
},
UpsampleNearest1D {
arg: Tensor,
target_size: usize,
},
UpsampleNearest1D(Tensor),
UpsampleNearest2D {
arg: Tensor,
target_h: usize,

View File

@ -42,7 +42,7 @@ pub enum OpCode {
Stop = b'.',
NewObj = 0x81,
EmptyList = b']',
BinFloat = b'G',
BinFloat = b'g',
Append = b'a',
Appends = b'e',
}
@ -462,10 +462,7 @@ impl Stack {
self.push(Object::Int(arg))
}
OpCode::BinFloat => {
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
let arg = r.read_f64::<byteorder::BigEndian>()?;
let arg = r.read_f64::<LittleEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => {

View File

@ -1,343 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, DeviceSlice};
pub struct QCudaStorage {
data: CudaSlice<u8>,
dtype: GgmlDType,
device: CudaDevice,
}
pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
fn dequantize(
data: &CudaSlice<u8>,
dtype: GgmlDType,
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
GgmlDType::Q5_0 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_0",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q5_1 => {
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
(
"dequantize_block_q5_1",
false,
CUDA_DEQUANTIZE_BLOCK_SIZE,
nb,
)
}
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, 1, 1),
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
if is_k {
let params = (data, &dst);
unsafe { func.launch(cfg, params) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (data, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
fn dequantize_mut_mal_vec(
data: &CudaSlice<u8>,
y: &cudarc::driver::CudaView<f32>,
dtype: GgmlDType,
ncols: usize,
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let kernel_name = match dtype {
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
shared_mem_bytes: 0,
};
let params = (data, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
Ok(QCudaStorage {
data,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
let fast_kernel = matches!(
self.dtype,
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q2K
| GgmlDType::Q3K
| GgmlDType::Q4K
| GgmlDType::Q5K
| GgmlDType::Q6K
| GgmlDType::Q8K
);
if fast_kernel {
return dequantize(&self.data, self.dtype, elem_count, self.device());
}
// Run the dequantization on cpu.
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let slice =
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
out.copy_from_slice(slice)
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
self.device
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
}
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
qcpu_storage.quantize(&src)?;
let data = qcpu_storage.data()?;
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
self.data = data;
Ok(())
}
pub fn storage_size_in_bytes(&self) -> usize {
self.data.len()
}
pub fn fwd(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
self.dequantize_matmul_vec(self_shape, storage, layout)
} else {
self.dequantize_matmul(self_shape, storage, layout)
}
}
}
impl QCudaStorage {
fn dequantize_matmul_vec(
&self,
self_shape: &crate::Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
vec![1, nrows]
};
Ok((out, out_shape.into()))
}
fn dequantize_matmul(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage;
let (n, k) = self_shape.dims2()?;
let (b, m, k2) = match layout.shape().dims() {
&[b, m, k2] => (b, m, k2),
&[m, k2] => (1, m, k2),
s => crate::bail!("unexpected shape for input {s:?}"),
};
if k2 != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
Ok((out, out_shape.into()))
}
}
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
slice.to_vec()
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
device: &CudaDevice,
data: &[T],
) -> Result<super::QStorage> {
let data = unsafe {
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
};
let data = device.htod_sync_copy(data).w()?;
Ok(QStorage::Cuda(QCudaStorage {
data,
device: device.clone(),
dtype: T::DTYPE,
}))
}

View File

@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};
pub struct QCudaStorage {
dtype: GgmlDType,
device: CudaDevice,
}
impl QCudaStorage {
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &CudaStorage,
_layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
Err(Error::NotCompiledWithCudaSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &CudaDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -41,10 +41,3 @@ impl QMetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &MetalDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -1,5 +1,7 @@
//! Support for the GGML file format.
#[cfg(feature = "metal")]
use super::metal::load_quantized_metal;
use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
#[cfg(feature = "metal")]
Device::Metal(metal) => load_quantized_metal(metal, data)?,
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal backend requires `metal` feature")
}
device => unimplemented!("Implement quantized tensor for device {device:?}"),
};
super::QTensor::new(data, dims)
}

View File

@ -34,8 +34,6 @@ impl QMetalStorage {
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
@ -45,62 +43,81 @@ impl QMetalStorage {
blit.end_encoding();
self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
f32::to_float(&vec, &mut out)?;
}
GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
half::f16::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
}
GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
}
GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ2K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
}
GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ3K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
}
GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ4K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
}
GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ5K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
}
GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ6K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
}
GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
let vec: Vec<crate::quantized::BlockQ8K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
}
}
@ -175,7 +192,7 @@ impl QMetalStorage {
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice,
data: &[T],
) -> Result<QStorage> {

View File

@ -4,7 +4,6 @@ use std::borrow::Cow;
#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
@ -15,13 +14,6 @@ pub mod metal;
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
@ -47,9 +39,8 @@ impl Device {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage))
}
Device::Cuda(cuda) => {
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
Ok(QStorage::Cuda(storage))
Device::Cuda(_cuda) => {
crate::bail!("Cuda ggml quantization not supported");
}
}
}
@ -58,7 +49,6 @@ impl Device {
pub enum QStorage {
Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage),
Cuda(cuda::QCudaStorage),
}
impl QStorage {
@ -66,7 +56,6 @@ impl QStorage {
match self {
QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(),
QStorage::Cuda(storage) => storage.dtype().block_size(),
}
}
@ -74,7 +63,6 @@ impl QStorage {
match self {
QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(),
QStorage::Cuda(storage) => storage.dtype(),
}
}
@ -82,7 +70,6 @@ impl QStorage {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
}
}
@ -90,7 +77,6 @@ impl QStorage {
match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
}
}
@ -100,7 +86,6 @@ impl QStorage {
storage.from_float(src.as_slice::<f32>()?)?;
}
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"),
}
Ok(())
@ -110,7 +95,6 @@ impl QStorage {
match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
}
}
@ -122,7 +106,7 @@ impl QStorage {
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data))
}
QStorage::Metal(_) | QStorage::Cuda(_) => {
QStorage::Metal(_storage) => {
crate::bail!("not implemented");
}
}
@ -440,7 +424,7 @@ impl crate::CustomOp1 for QTensor {
#[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage,
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
QStorage::Metal(_) => crate::bail!("Invalid storage"),
};
let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@ -460,18 +444,6 @@ impl crate::CustomOp1 for QTensor {
};
self_storage.fwd(&self.shape, storage, layout)
}
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Cuda(cuda) => cuda,
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
}
impl crate::Module for QMatMul {

View File

@ -352,10 +352,6 @@ impl Storage {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),

View File

@ -1015,7 +1015,7 @@ impl Tensor {
/// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self
.storage()
.upsample_nearest1d(self.layout(), target_size)?;

View File

@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape)
print(res)
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
print(res.shape)
print(res)
*/
fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new(
@ -53,7 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -62,17 +59,6 @@ fn conv1d(dev: &Device) -> Result<()> {
4.7076, -5.9745, -0.8276, 1.621
],
);
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
assert_eq!(res.dims(), [1, 4, 7]);
assert_eq!(
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
[
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
]
);
Ok(())
}

View File

@ -283,38 +283,19 @@ fn unary_grad(device: &Device) -> Result<()> {
[1.0881, 0.9277, 1.0527, 0.5747],
);
if device.is_cpu() {
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
let y = x.interpolate1d(12)?.reshape(36)?;
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
33., 34., 35., 36.,
],
device,
)?;
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(grad_x, 4)?,
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
);
}
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
35., 36.,
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
@ -345,11 +326,15 @@ fn unary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
35., 36.,
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;

View File

@ -178,6 +178,10 @@ test_device!(
);
fn quantize_q4_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
@ -205,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
}
fn quantize_q4_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
@ -231,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
}
fn quantize_q5_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@ -257,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
}
fn quantize_q5_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
@ -357,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
}
fn quantize_q2k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q2K;
let src = get_test_vector2(0.5, 1024, device)?;
@ -391,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
}
fn quantize_q3k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q3K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -424,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
}
fn quantize_q4k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q4K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -457,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
}
fn quantize_q5k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q5K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -490,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
}
fn quantize_q6k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q6K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -523,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
}
fn quantize_q8k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q8K;
let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -738,6 +778,10 @@ macro_rules! quantized_matmul {
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> {
if device.is_cuda() {
// TODO Enable Cuda GGML sometime maybe.
return Ok(());
}
test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(())
}

View File

@ -12,7 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { workspace = true }
candle-datasets = { workspace = true, optional = true }
candle-datasets = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
@ -30,7 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
symphonia = { version = "0.5.3", features = ["all"] }
tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true }
@ -80,26 +80,6 @@ required-features = ["onnx"]
name = "onnx_basics"
required-features = ["onnx"]
[[example]]
name = "whisper"
required-features = ["symphonia"]
[[example]]
name = "whisper-microphone"
required-features = ["microphone"]
[[example]]
name = "mnist-training"
required-features = ["candle-datasets"]
[[example]]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["symphonia"]
[[example]]
name = "metavoice"
required-features = ["symphonia"]

View File

@ -1 +0,0 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -1,20 +0,0 @@
# candle-efficientvit
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).
This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet : 2.25%
alp : 0.46%
```

View File

@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientvit;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
M0,
M1,
M2,
M3,
M4,
M5,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::M0 => "m0",
Self::M1 => "m1",
Self::M2 => "m2",
Self::M3 => "m3",
Self::M4 => "m4",
Self::M5 => "m5",
};
format!("timm/efficientvit_{}.r224_in1k", name)
}
fn config(&self) -> efficientvit::Config {
match self {
Self::M0 => efficientvit::Config::m0(),
Self::M1 => efficientvit::Config::m1(),
Self::M2 => efficientvit::Config::m2(),
Self::M3 => efficientvit::Config::m3(),
Self::M4 => efficientvit::Config::m4(),
Self::M5 => efficientvit::Config::m5(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::M0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-endocec
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.
## Running one example
```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
containing EnCodec tokens for the input audio file.

View File

@ -1,143 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
AudioToAudio,
AudioToCode,
CodeToAudio,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some encodec tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
#[arg(long)]
model: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::default();
let model = Model::new(&config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
codes
}
Action::AudioToCode | Action::AudioToAudio => {
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
}
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
println!("codes shape: {:?}", codes.shape());
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
Ok(())
}

View File

@ -1,27 +0,0 @@
# candle-mistral: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
```

View File

@ -1,256 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => match model_id.as_str() {
"7b-it" => "google/gemma-7b-it".to_string(),
"7b" => "google/gemma-7b".to_string(),
"2b-it" => "google/gemma-2b-it".to_string(),
"2b" => "google/gemma-2b".to_string(),
_ => model_id.to_string(),
},
None => "google/gemma-2b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -57,7 +57,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)]
#[arg(long, default_value_t = 100)]
sample_len: usize,
/// Disable the key-value cache.
@ -120,7 +120,7 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache) = {
let (llama, tokenizer_filename, cache) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
@ -143,10 +143,11 @@ fn main() -> Result<()> {
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -156,7 +157,6 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop");
print!("{prompt}");
@ -172,7 +172,7 @@ fn main() -> Result<()> {
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
@ -190,16 +190,18 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(

View File

@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;
use model::{Cache, Config, Llama};
use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
@ -160,10 +160,10 @@ enum Model {
}
impl Model {
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
}
}
}
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?;
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let tokens = match &args.pretokenized_dir {
None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, &mut cache)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
}
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config, mut cache) = if is_gguf {
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
(model, config, cache)
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
};
println!("starting the inference loop");
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now();
for index in 0.. {
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos, &mut cache)?;
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",

View File

@ -8,7 +8,6 @@ fn valid_loss(
model: &Llama,
args: &crate::TrainingCmd,
device: &Device,
cache: &mut Cache,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -16,7 +15,7 @@ fn valid_loss(
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, cache)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1;
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?;
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0, &mut cache)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss.
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}");
}
if batch_index > 0 && batch_index % 1000 == 0 {

View File

@ -1,18 +0,0 @@
# candle-metavoice
MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
details on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).
Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
probably because of some implementation discrepancies.
## Run an example
```bash
cargo run --example metavoice --release -- \\
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -1,342 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{
adapters, gpt, speaker_encoder, tokenizers, transformer,
};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};
pub const ENCODEC_NTOKENS: u32 = 1024;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
use symphonia::core::audio::Signal;
use symphonia::core::conv::FromSample;
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
F32,
F16,
Bf16,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The guidance scale.
#[arg(long, default_value_t = 3.0)]
guidance_scale: f64,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
temperature: f64,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The maximum number of tokens to generate for the first stage.
#[arg(long, default_value_t = 2000)]
max_tokens: u64,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long)]
first_stage_meta: Option<String>,
#[arg(long)]
first_stage_weights: Option<String>,
#[arg(long)]
second_stage_weights: Option<String>,
#[arg(long)]
speaker_encoder_weights: Option<String>,
#[arg(long)]
encodec_weights: Option<String>,
/// The speaker embeddings, either an audio files in which case they are extracted, or a
/// safetensors file with the embeddings already extracted.
#[arg(long)]
spk_emb: Option<String>,
#[arg(long, default_value = "f32")]
dtype: ArgDType,
}
fn mel_filters() -> Result<Vec<f32>> {
let mel_bytes = include_bytes!("melfilters40.bytes").as_slice();
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
Ok(mel_filters)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let device = candle_examples::device(args.cpu)?;
let api = Api::new()?;
let repo = api.model("lmz/candle-metavoice".to_string());
let first_stage_meta = match &args.first_stage_meta {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.meta.json")?,
};
let first_stage_meta: serde_json::Value =
serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
let first_stage_tokenizer = match first_stage_meta.as_object() {
None => anyhow::bail!("not a json object"),
Some(j) => match j.get("tokenizer") {
None => anyhow::bail!("no tokenizer key"),
Some(j) => j,
},
};
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
let first_stage_weights = match &args.first_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.safetensors")?,
};
let second_stage_weights = match &args.second_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("second_stage.safetensors")?,
};
let encodec_weights = match args.encodec_weights {
Some(w) => std::path::PathBuf::from(w),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let dtype = match args.dtype {
ArgDType::F32 => DType::F32,
ArgDType::F16 => DType::F16,
ArgDType::Bf16 => DType::BF16,
};
let first_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
let first_stage_config = transformer::Config::cfg1b_v0_1();
let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
let second_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
let second_stage_config = gpt::Config::cfg1b_v0_1();
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
let encodec_device = if device.is_metal() {
&candle::Device::Cpu
} else {
&device
};
let encodec_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
let encodec_config = encodec::Config::default();
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
println!("prompt: '{}'", args.prompt);
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
let mut tokens = prompt_tokens.clone();
println!("{tokens:?}");
let safetensors_embeddings = args
.spk_emb
.as_ref()
.map_or(true, |v| v.ends_with("safetensors"));
let spk_emb = if safetensors_embeddings {
let spk_emb_file = match &args.spk_emb {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("spk_emb.safetensors")?,
};
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
match spk_emb.get("spk_emb") {
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
}
} else {
let weights = match &args.speaker_encoder_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("speaker_encoder.safetensors")?,
};
let mel_filters = mel_filters()?;
let config = speaker_encoder::Config::cfg();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
let model = speaker_encoder::Model::new(config, vb)?;
let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
if sample_rate != 16_000 {
eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
}
model.embed_utterance(
&pcm,
&mel_filters,
/* rate */ 1.3,
/* min_c */ 0.75,
&device,
)?
};
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
// First stage generation.
for index in 0..args.max_tokens {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &device)?;
let input = Tensor::stack(&[&input, &input], 0)?;
let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
let logits0 = logits.i((0, 0))?;
let logits1 = logits.i((1, 0))?;
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
let logits = logits.to_dtype(DType::F32)?;
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
print!(".");
std::io::stdout().flush()?;
if next_token == 2048 {
break;
}
}
println!();
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
println!("text ids len: {}", text_ids.len());
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
// TODO: Use the config rather than hardcoding the offset here.
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
let mut hierarchies_in1 =
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
let mut hierarchies_in2 = [
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
ids2.as_slice(),
&[ENCODEC_NTOKENS],
]
.concat();
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
let logits = second_stage_model.forward(&in_x)?;
println!("sampling from logits...");
let mut codes = vec![];
for logits in logits.iter() {
let logits = logits.squeeze(0)?;
let (seq_len, _) = logits.dims2()?;
let mut codes_ = Vec::with_capacity(seq_len);
for step in 0..seq_len {
let logits = logits.i(step)?.to_dtype(DType::F32)?;
let logits = &(&logits / 1.0)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
let sample = distr.sample(&mut rng) as u32;
codes_.push(sample)
}
codes.push(codes_)
}
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
let codes = Tensor::cat(&[in_x, codes], 1)?;
println!("codes: {codes}");
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
let codes = codes.i(0)?.to_vec2::<u32>()?;
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
println!("text_ids len: {:?}", text_ids.len());
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
println!("audio_ids shape: {:?}", audio_ids.shape());
let pcm = encodec_model.decode(&audio_ids)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

View File

@ -152,7 +152,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long)]

View File

@ -143,7 +143,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]

View File

@ -0,0 +1,580 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
#[derive(Debug, Clone, PartialEq)]
enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
target_bandwidths: Vec<f64>,
sampling_rate: usize,
audio_channels: usize,
normalize: bool,
chunk_length_s: Option<usize>,
overlap: Option<usize>,
hidden_size: usize,
num_filters: usize,
num_residual_layers: usize,
upsampling_ratios: Vec<usize>,
norm_type: NormType,
kernel_size: usize,
last_kernel_size: usize,
residual_kernel_size: usize,
dilation_growth_rate: usize,
use_causal_conv: bool,
pad_mode: &'static str,
compress: usize,
num_lstm_layers: usize,
trim_right_ratio: f64,
codebook_size: usize,
codebook_dim: Option<usize>,
use_conv_shortcut: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
pad_mode: "reflect",
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
pub fn musicgen_small() -> Self {
Self {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
pad_mode: "reflect",
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
}
}
fn codebook_dim(&self) -> usize {
self.codebook_dim.unwrap_or(self.codebook_size)
}
fn frame_rate(&self) -> usize {
let hop_length: usize = self.upsampling_ratios.iter().product();
(self.sampling_rate + hop_length - 1) / hop_length
}
fn num_quantizers(&self) -> usize {
let num = 1000f64
* self
.target_bandwidths
.last()
.expect("empty target_bandwidths");
(num as usize) / (self.frame_rate() * 10)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Debug)]
struct EncodecEuclideanCodebook {
inited: Tensor,
cluster_size: Tensor,
embed: Tensor,
embed_avg: Tensor,
}
impl EncodecEuclideanCodebook {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let inited = vb.get(1, "inited")?;
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, "embed")?;
let embed_avg = vb.get(e_shape, "embed_avg")?;
Ok(Self {
inited,
cluster_size,
embed,
embed_avg,
})
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.embed.embedding(embed_ind)?;
Ok(quantize)
}
}
#[derive(Debug)]
struct EncodecVectorQuantization {
codebook: EncodecEuclideanCodebook,
}
impl EncodecVectorQuantization {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
Ok(Self { codebook })
}
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.codebook.decode(embed_ind)?;
let quantize = quantize.transpose(1, 2)?;
Ok(quantize)
}
}
#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
layers: Vec<EncodecVectorQuantization>,
}
impl EncodecResidualVectorQuantizer {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("layers");
let layers = (0..cfg.num_quantizers())
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
}
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
if codes.dim(0)? != self.layers.len() {
candle::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
)
}
for (i, layer) in self.layers.iter().enumerate() {
let quantized = layer.decode(&codes.i(i)?)?;
quantized_out = quantized.broadcast_add(&quantized_out)?;
}
Ok(quantized_out)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Debug)]
struct EncodecLSTM {
layers: Vec<candle_nn::LSTM>,
}
impl EncodecLSTM {
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
let vb = &vb.pp("lstm");
let mut layers = vec![];
for layer_idx in 0..cfg.num_lstm_layers {
let config = candle_nn::LSTMConfig {
layer_idx,
..Default::default()
};
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
layers.push(lstm)
}
Ok(Self { layers })
}
}
impl Module for EncodecLSTM {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
let mut xs = xs.clone();
for layer in self.layers.iter() {
let states = layer.seq(&xs)?;
xs = layer.states_to_tensor(&states)?;
}
Ok(xs)
}
}
#[derive(Debug)]
struct EncodecConvTranspose1d {
weight_g: Tensor,
weight_v: Tensor,
bias: Tensor,
}
impl EncodecConvTranspose1d {
fn load(
in_c: usize,
out_c: usize,
k: usize,
_stride: usize,
vb: VarBuilder,
_cfg: &Config,
) -> Result<Self> {
let vb = &vb.pp("conv");
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
let bias = vb.get(out_c, "bias")?;
Ok(Self {
weight_g,
weight_v,
bias,
})
}
}
impl Module for EncodecConvTranspose1d {
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
#[derive(Debug)]
struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
}
impl EncodecConv1d {
fn load(
in_c: usize,
out_c: usize,
kernel_size: usize,
stride: usize,
vb: VarBuilder,
cfg: &Config,
) -> Result<Self> {
let conv = match cfg.norm_type {
NormType::WeightNorm => conv1d_weight_norm(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
})
}
}
impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
// TODO: padding, depending on causal.
let xs = self.conv.forward(xs)?;
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Debug)]
struct EncodecResnetBlock {
block_conv1: EncodecConv1d,
block_conv2: EncodecConv1d,
shortcut: Option<EncodecConv1d>,
}
impl EncodecResnetBlock {
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
candle::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();
let block_conv1 =
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
layer.inc();
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
Some(conv)
} else {
None
};
Ok(Self {
block_conv1,
block_conv2,
shortcut,
})
}
}
impl Module for EncodecResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs.clone();
let xs = xs.elu(1.)?;
let xs = self.block_conv1.forward(&xs)?;
let xs = xs.elu(1.)?;
let xs = self.block_conv2.forward(&xs)?;
let xs = match &self.shortcut {
None => (xs + residual)?,
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
};
Ok(xs)
}
}
struct Layer<'a> {
vb: VarBuilder<'a>,
cnt: usize,
}
impl<'a> Layer<'a> {
fn new(vb: VarBuilder<'a>) -> Self {
Self { vb, cnt: 0 }
}
fn inc(&mut self) {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb
}
}
#[derive(Debug)]
struct EncodecEncoder {
init_conv: EncodecConv1d,
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
final_lstm: EncodecLSTM,
final_conv: EncodecConv1d,
}
impl EncodecEncoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let init_conv = EncodecConv1d::load(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
layer.next(),
cfg,
)?;
let mut sampling_layers = vec![];
let mut scaling = 1;
for &ratio in cfg.upsampling_ratios.iter().rev() {
let current_scale = scaling * cfg.num_filters;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
layer.inc(); // ELU
let conv1d = EncodecConv1d::load(
current_scale,
current_scale * 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
sampling_layers,
final_conv,
final_lstm,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?;
for (resnets, conv) in self.sampling_layers.iter() {
for resnet in resnets.iter() {
xs = xs.apply(resnet)?;
}
xs = xs.elu(1.0)?.apply(conv)?;
}
xs.apply(&self.final_lstm)?
.elu(1.0)?
.apply(&self.final_conv)
}
}
#[derive(Debug)]
struct EncodecDecoder {
init_conv: EncodecConv1d,
init_lstm: EncodecLSTM,
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
final_conv: EncodecConv1d,
}
impl EncodecDecoder {
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::load(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
layer.inc(); // ELU
let conv1d = EncodecConvTranspose1d::load(
current_scale,
current_scale / 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
sampling_layers.push((conv1d, resnets));
scaling /= 2;
}
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters,
cfg.audio_channels,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
init_lstm,
sampling_layers,
final_conv,
})
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
for (conv, resnets) in self.sampling_layers.iter() {
xs = xs.elu(1.)?.apply(conv)?;
for resnet in resnets.iter() {
xs = xs.apply(resnet)?
}
}
xs.elu(1.)?.apply(&self.final_conv)
}
}
#[derive(Debug)]
pub struct EncodecModel {
encoder: EncodecEncoder,
decoder: EncodecDecoder,
quantizer: EncodecResidualVectorQuantizer,
}
impl EncodecModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}

View File

@ -10,7 +10,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod encodec_model;
mod musicgen_model;
mod nn;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

View File

@ -1,9 +1,10 @@
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder,
};
use candle_transformers::models::{encodec, t5};
use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: encodec::Model,
pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM,
cfg: GenConfig,
}
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
pub struct GenConfig {
musicgen: Config,
t5: t5::Config,
encodec: encodec::Config,
encodec: crate::encodec_model::Config,
}
impl GenConfig {
pub fn small() -> Self {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
let encodec = encodec::Config {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: encodec::NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
// This should be Reflect and not Replicate but Reflect does not work yet.
pad_mode: encodec::PadMode::Replicate,
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
};
Self {
musicgen: Config::musicgen_small(),
t5: t5::Config::musicgen_small(),
encodec,
encodec: encodec_model::Config::musicgen_small(),
}
}
}
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self {
text_encoder,

View File

@ -0,0 +1,20 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}

View File

@ -212,14 +212,6 @@ struct Args {
#[arg(long)]
verbose_prompt: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@ -369,7 +361,7 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
@ -495,20 +487,11 @@ fn main() -> anyhow::Result<()> {
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt {
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);

View File

@ -2,8 +2,8 @@
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 and v6 versions and can be used with
Eagle 7B([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
available, candle implements the v5 version and can be used with Eagle 7B([blog
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "

View File

@ -7,36 +7,13 @@ extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
const EOS_TOKEN_ID: u32 = 261;
enum Model {
M5(M5),
Q5(Q5),
M6(M6),
Q6(Q6),
}
impl Model {
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
match self {
Self::M5(m) => m.forward(xs, state),
Self::Q5(m) => m.forward(xs, state),
Self::M6(m) => m.forward(xs, state),
Self::Q6(m) => m.forward(xs, state),
}
}
}
struct TextGeneration {
model: Model,
config: Config,
@ -106,9 +83,6 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == EOS_TOKEN_ID || next_token == 0 {
break;
}
print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?;
@ -129,7 +103,6 @@ enum Which {
Eagle7b,
World1b5,
World3b,
World6_1b6,
}
impl std::fmt::Display for Which {
@ -144,7 +117,6 @@ impl Which {
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
Self::World3b => "RWKV/rwkv-5-world-3b",
Self::World6_1b6 => "paperfun/rwkv",
}
}
@ -152,7 +124,6 @@ impl Which {
match self {
Self::Eagle7b => "refs/pr/1",
Self::World1b5 | Self::World3b => "refs/pr/2",
Self::World6_1b6 => "main",
}
}
}
@ -205,9 +176,6 @@ struct Args {
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
quantized: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
@ -268,27 +236,7 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.quantized {
vec![match args.which {
Which::World1b5 => api
.model("lmz/candle-rwkv".to_string())
.get("world1b5-q4k.gguf")?,
Which::World3b => api
.model("lmz/candle-rwkv".to_string())
.get("world3b-q4k.gguf")?,
Which::Eagle7b => api
.model("lmz/candle-rwkv".to_string())
.get("eagle7b-q4k.gguf")?,
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
}]
} else {
vec![match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => {
repo.get("model.safetensors")?
}
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
}]
}
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
@ -297,21 +245,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let model = if args.quantized {
let filename = &filenames[0];
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
}
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
}
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(

View File

@ -1,28 +0,0 @@
# candle-segformer
- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
## How to run the example
If you want you can use the example images from this [pull request][pr], download them and supply the path to the image as an argument to the example.
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```
Example output for classification:
```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

View File

@ -1,752 +0,0 @@
[
{
"index": 1,
"color": "#787878",
"label": "wall"
},
{
"index": 2,
"color": "#B47878",
"label": "building;edifice"
},
{
"index": 3,
"color": "#06E6E6",
"label": "sky"
},
{
"index": 4,
"color": "#503232",
"label": "floor;flooring"
},
{
"index": 5,
"color": "#04C803",
"label": "tree"
},
{
"index": 6,
"color": "#787850",
"label": "ceiling"
},
{
"index": 7,
"color": "#8C8C8C",
"label": "road;route"
},
{
"index": 8,
"color": "#CC05FF",
"label": "bed"
},
{
"index": 9,
"color": "#E6E6E6",
"label": "windowpane;window"
},
{
"index": 10,
"color": "#04FA07",
"label": "grass"
},
{
"index": 11,
"color": "#E005FF",
"label": "cabinet"
},
{
"index": 12,
"color": "#EBFF07",
"label": "sidewalk;pavement"
},
{
"index": 13,
"color": "#96053D",
"label": "person;individual;someone;somebody;mortal;soul"
},
{
"index": 14,
"color": "#787846",
"label": "earth;ground"
},
{
"index": 15,
"color": "#08FF33",
"label": "door;double;door"
},
{
"index": 16,
"color": "#FF0652",
"label": "table"
},
{
"index": 17,
"color": "#8FFF8C",
"label": "mountain;mount"
},
{
"index": 18,
"color": "#CCFF04",
"label": "plant;flora;plant;life"
},
{
"index": 19,
"color": "#FF3307",
"label": "curtain;drape;drapery;mantle;pall"
},
{
"index": 20,
"color": "#CC4603",
"label": "chair"
},
{
"index": 21,
"color": "#0066C8",
"label": "car;auto;automobile;machine;motorcar"
},
{
"index": 22,
"color": "#3DE6FA",
"label": "water"
},
{
"index": 23,
"color": "#FF0633",
"label": "painting;picture"
},
{
"index": 24,
"color": "#0B66FF",
"label": "sofa;couch;lounge"
},
{
"index": 25,
"color": "#FF0747",
"label": "shelf"
},
{
"index": 26,
"color": "#FF09E0",
"label": "house"
},
{
"index": 27,
"color": "#0907E6",
"label": "sea"
},
{
"index": 28,
"color": "#DCDCDC",
"label": "mirror"
},
{
"index": 29,
"color": "#FF095C",
"label": "rug;carpet;carpeting"
},
{
"index": 30,
"color": "#7009FF",
"label": "field"
},
{
"index": 31,
"color": "#08FFD6",
"label": "armchair"
},
{
"index": 32,
"color": "#07FFE0",
"label": "seat"
},
{
"index": 33,
"color": "#FFB806",
"label": "fence;fencing"
},
{
"index": 34,
"color": "#0AFF47",
"label": "desk"
},
{
"index": 35,
"color": "#FF290A",
"label": "rock;stone"
},
{
"index": 36,
"color": "#07FFFF",
"label": "wardrobe;closet;press"
},
{
"index": 37,
"color": "#E0FF08",
"label": "lamp"
},
{
"index": 38,
"color": "#6608FF",
"label": "bathtub;bathing;tub;bath;tub"
},
{
"index": 39,
"color": "#FF3D06",
"label": "railing;rail"
},
{
"index": 40,
"color": "#FFC207",
"label": "cushion"
},
{
"index": 41,
"color": "#FF7A08",
"label": "base;pedestal;stand"
},
{
"index": 42,
"color": "#00FF14",
"label": "box"
},
{
"index": 43,
"color": "#FF0829",
"label": "column;pillar"
},
{
"index": 44,
"color": "#FF0599",
"label": "signboard;sign"
},
{
"index": 45,
"color": "#0633FF",
"label": "chest;of;drawers;chest;bureau;dresser"
},
{
"index": 46,
"color": "#EB0CFF",
"label": "counter"
},
{
"index": 47,
"color": "#A09614",
"label": "sand"
},
{
"index": 48,
"color": "#00A3FF",
"label": "sink"
},
{
"index": 49,
"color": "#8C8C8C",
"label": "skyscraper"
},
{
"index": 50,
"color": "#FA0A0F",
"label": "fireplace;hearth;open;fireplace"
},
{
"index": 51,
"color": "#14FF00",
"label": "refrigerator;icebox"
},
{
"index": 52,
"color": "#1FFF00",
"label": "grandstand;covered;stand"
},
{
"index": 53,
"color": "#FF1F00",
"label": "path"
},
{
"index": 54,
"color": "#FFE000",
"label": "stairs;steps"
},
{
"index": 55,
"color": "#99FF00",
"label": "runway"
},
{
"index": 56,
"color": "#0000FF",
"label": "case;display;case;showcase;vitrine"
},
{
"index": 57,
"color": "#FF4700",
"label": "pool;table;billiard;table;snooker;table"
},
{
"index": 58,
"color": "#00EBFF",
"label": "pillow"
},
{
"index": 59,
"color": "#00ADFF",
"label": "screen;door;screen"
},
{
"index": 60,
"color": "#1F00FF",
"label": "stairway;staircase"
},
{
"index": 61,
"color": "#0BC8C8",
"label": "river"
},
{
"index": 62,
"color": "#FF5200",
"label": "bridge;span"
},
{
"index": 63,
"color": "#00FFF5",
"label": "bookcase"
},
{
"index": 64,
"color": "#003DFF",
"label": "blind;screen"
},
{
"index": 65,
"color": "#00FF70",
"label": "coffee;table;cocktail;table"
},
{
"index": 66,
"color": "#00FF85",
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
},
{
"index": 67,
"color": "#FF0000",
"label": "flower"
},
{
"index": 68,
"color": "#FFA300",
"label": "book"
},
{
"index": 69,
"color": "#FF6600",
"label": "hill"
},
{
"index": 70,
"color": "#C2FF00",
"label": "bench"
},
{
"index": 71,
"color": "#008FFF",
"label": "countertop"
},
{
"index": 72,
"color": "#33FF00",
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
},
{
"index": 73,
"color": "#0052FF",
"label": "palm;palm;tree"
},
{
"index": 74,
"color": "#00FF29",
"label": "kitchen;island"
},
{
"index": 75,
"color": "#00FFAD",
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
},
{
"index": 76,
"color": "#0A00FF",
"label": "swivel;chair"
},
{
"index": 77,
"color": "#ADFF00",
"label": "boat"
},
{
"index": 78,
"color": "#00FF99",
"label": "bar"
},
{
"index": 79,
"color": "#FF5C00",
"label": "arcade;machine"
},
{
"index": 80,
"color": "#FF00FF",
"label": "hovel;hut;hutch;shack;shanty"
},
{
"index": 81,
"color": "#FF00F5",
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
},
{
"index": 82,
"color": "#FF0066",
"label": "towel"
},
{
"index": 83,
"color": "#FFAD00",
"label": "light;light;source"
},
{
"index": 84,
"color": "#FF0014",
"label": "truck;motortruck"
},
{
"index": 85,
"color": "#FFB8B8",
"label": "tower"
},
{
"index": 86,
"color": "#001FFF",
"label": "chandelier;pendant;pendent"
},
{
"index": 87,
"color": "#00FF3D",
"label": "awning;sunshade;sunblind"
},
{
"index": 88,
"color": "#0047FF",
"label": "streetlight;street;lamp"
},
{
"index": 89,
"color": "#FF00CC",
"label": "booth;cubicle;stall;kiosk"
},
{
"index": 90,
"color": "#00FFC2",
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
},
{
"index": 91,
"color": "#00FF52",
"label": "airplane;aeroplane;plane"
},
{
"index": 92,
"color": "#000AFF",
"label": "dirt;track"
},
{
"index": 93,
"color": "#0070FF",
"label": "apparel;wearing;apparel;dress;clothes"
},
{
"index": 94,
"color": "#3300FF",
"label": "pole"
},
{
"index": 95,
"color": "#00C2FF",
"label": "land;ground;soil"
},
{
"index": 96,
"color": "#007AFF",
"label": "bannister;banister;balustrade;balusters;handrail"
},
{
"index": 97,
"color": "#00FFA3",
"label": "escalator;moving;staircase;moving;stairway"
},
{
"index": 98,
"color": "#FF9900",
"label": "ottoman;pouf;pouffe;puff;hassock"
},
{
"index": 99,
"color": "#00FF0A",
"label": "bottle"
},
{
"index": 100,
"color": "#FF7000",
"label": "buffet;counter;sideboard"
},
{
"index": 101,
"color": "#8FFF00",
"label": "poster;posting;placard;notice;bill;card"
},
{
"index": 102,
"color": "#5200FF",
"label": "stage"
},
{
"index": 103,
"color": "#A3FF00",
"label": "van"
},
{
"index": 104,
"color": "#FFEB00",
"label": "ship"
},
{
"index": 105,
"color": "#08B8AA",
"label": "fountain"
},
{
"index": 106,
"color": "#8500FF",
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
},
{
"index": 107,
"color": "#00FF5C",
"label": "canopy"
},
{
"index": 108,
"color": "#B800FF",
"label": "washer;automatic;washer;washing;machine"
},
{
"index": 109,
"color": "#FF001F",
"label": "plaything;toy"
},
{
"index": 110,
"color": "#00B8FF",
"label": "swimming;pool;swimming;bath;natatorium"
},
{
"index": 111,
"color": "#00D6FF",
"label": "stool"
},
{
"index": 112,
"color": "#FF0070",
"label": "barrel;cask"
},
{
"index": 113,
"color": "#5CFF00",
"label": "basket;handbasket"
},
{
"index": 114,
"color": "#00E0FF",
"label": "waterfall;falls"
},
{
"index": 115,
"color": "#70E0FF",
"label": "tent;collapsible;shelter"
},
{
"index": 116,
"color": "#46B8A0",
"label": "bag"
},
{
"index": 117,
"color": "#A300FF",
"label": "minibike;motorbike"
},
{
"index": 118,
"color": "#9900FF",
"label": "cradle"
},
{
"index": 119,
"color": "#47FF00",
"label": "oven"
},
{
"index": 120,
"color": "#FF00A3",
"label": "ball"
},
{
"index": 121,
"color": "#FFCC00",
"label": "food;solid;food"
},
{
"index": 122,
"color": "#FF008F",
"label": "step;stair"
},
{
"index": 123,
"color": "#00FFEB",
"label": "tank;storage;tank"
},
{
"index": 124,
"color": "#85FF00",
"label": "trade;name;brand;name;brand;marque"
},
{
"index": 125,
"color": "#FF00EB",
"label": "microwave;microwave;oven"
},
{
"index": 126,
"color": "#F500FF",
"label": "pot;flowerpot"
},
{
"index": 127,
"color": "#FF007A",
"label": "animal;animate;being;beast;brute;creature;fauna"
},
{
"index": 128,
"color": "#FFF500",
"label": "bicycle;bike;wheel;cycle"
},
{
"index": 129,
"color": "#0ABED4",
"label": "lake"
},
{
"index": 130,
"color": "#D6FF00",
"label": "dishwasher;dish;washer;dishwashing;machine"
},
{
"index": 131,
"color": "#00CCFF",
"label": "screen;silver;screen;projection;screen"
},
{
"index": 132,
"color": "#1400FF",
"label": "blanket;cover"
},
{
"index": 133,
"color": "#FFFF00",
"label": "sculpture"
},
{
"index": 134,
"color": "#0099FF",
"label": "hood;exhaust;hood"
},
{
"index": 135,
"color": "#0029FF",
"label": "sconce"
},
{
"index": 136,
"color": "#00FFCC",
"label": "vase"
},
{
"index": 137,
"color": "#2900FF",
"label": "traffic;light;traffic;signal;stoplight"
},
{
"index": 138,
"color": "#29FF00",
"label": "tray"
},
{
"index": 139,
"color": "#AD00FF",
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
},
{
"index": 140,
"color": "#00F5FF",
"label": "fan"
},
{
"index": 141,
"color": "#4700FF",
"label": "pier;wharf;wharfage;dock"
},
{
"index": 142,
"color": "#7A00FF",
"label": "crt;screen"
},
{
"index": 143,
"color": "#00FFB8",
"label": "plate"
},
{
"index": 144,
"color": "#005CFF",
"label": "monitor;monitoring;device"
},
{
"index": 145,
"color": "#B8FF00",
"label": "bulletin;board;notice;board"
},
{
"index": 146,
"color": "#0085FF",
"label": "shower"
},
{
"index": 147,
"color": "#FFD600",
"label": "radiator"
},
{
"index": 148,
"color": "#19C2C2",
"label": "glass;drinking;glass"
},
{
"index": 149,
"color": "#66FF00",
"label": "clock"
},
{
"index": 150,
"color": "#5C00FF",
"label": "flag"
}
]

View File

@ -1,155 +0,0 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;
#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
#[arg(long, help = "use cpu")]
cpu: bool,
#[command(subcommand)]
command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
)]
model_name: String,
#[arg(
long,
help = "path to the label file in json format",
default_value = "candle-examples/examples/segformer/assets/labels.json"
)]
label_path: PathBuf,
#[arg(long, help = "path to for the output mask image")]
output_path: PathBuf,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Args, Debug)]
struct ClassificationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "paolinox/segformer-finetuned-food101"
)]
model_name: String,
#[arg(help = "path to image as input")]
image: PathBuf,
}
#[derive(Subcommand, Debug)]
enum Commands {
Segment(SegmentationArgs),
Classify(ClassificationArgs),
}
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
println!("loading model {} via huggingface hub", model_name);
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name.clone());
let model_file = api.get("model.safetensors")?;
println!("model {} downloaded and loaded", model_name);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
let config = std::fs::read_to_string(api.get("config.json")?)?;
let config: Config = serde_json::from_str(&config)?;
println!("{:?}", config);
Ok((vb, config))
}
#[derive(Debug, serde::Deserialize)]
struct LabelItem {
index: u32,
color: String,
}
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
let label_file = std::fs::read_to_string(&args.label_path)?;
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
let label_colors: HashMap<u32, Rgb<u8>> = label_items
.iter()
.map(|x| {
(x.index - 1, {
let color = x.color.trim_start_matches('#');
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
Rgb([r, g, b])
})
})
.collect();
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = label_items.len();
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
let segmentations = model.forward(&image)?;
// generate a mask image
let mask = &segmentations.squeeze(0)?.argmax(0)?;
let (h, w) = mask.dims2()?;
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
let mask = mask
.iter()
.flat_map(|x| label_colors[x].data())
.collect::<Vec<u8>>();
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
// resize
let mask = image::DynamicImage::from(mask);
let mask = mask.resize_to_fill(
w as u32 * 4,
h as u32 * 4,
image::imageops::FilterType::CatmullRom,
);
mask.save(args.output_path.clone())?;
println!("mask image saved to {:?}", args.output_path);
Ok(())
}
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = 7;
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
let classification = model.forward(&image)?;
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
let classification = classification.squeeze(0)?;
println!(
"classification logits {:?}",
classification.to_vec1::<f32>()?
);
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
let label_id = format!("{}", label_id);
println!("label: {}", config.id2label[&label_id]);
Ok(())
}
pub fn main() -> anyhow::Result<()> {
let args = CliArgs::parse();
let device = candle_examples::device(args.cpu)?;
if let Commands::Segment(args) = args.command {
segmentation_task(args, &device)?
} else if let Commands::Classify(args) = args.command {
classification_task(args, &device)?
}
Ok(())
}

View File

@ -57,7 +57,7 @@ The downside is some long compilation time. You can set the
`/home/user/.candle` to ensures that the compilation artifacts are properly
cached.
Enabling flash-attention requires both a feature flag, `--features flash-attn`
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`.
Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs

View File

@ -1,253 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::starcoder2::Model;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => "bigcode/starcoder2-3b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let config_file = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let tokenizer_file = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,29 +0,0 @@
use candle::{Result, Tensor};
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
pub fn normalize_loudness(
wav: &Tensor,
sample_rate: u32,
loudness_compressor: bool,
) -> Result<Tensor> {
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
if energy < 2e-3 {
return Ok(wav.clone());
}
let wav_array = wav.to_vec1::<f32>()?;
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
meter.push(wav_array.into_iter());
let power = meter.as_100ms_windows();
let loudness = match crate::bs1770::gated_mean(power) {
None => return Ok(wav.clone()),
Some(gp) => gp.loudness_lkfs() as f64,
};
let delta_loudness = -14. - loudness;
let gain = 10f64.powf(delta_loudness / 20.);
let wav = (wav * gain)?;
if loudness_compressor {
wav.tanh()
} else {
Ok(wav)
}
}

View File

@ -1,506 +0,0 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! # [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//! meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//! channel_power[0].as_ref(),
//! channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//! stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```
use std::f32;
/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
a1: f32,
a2: f32,
b0: f32,
b1: f32,
b2: f32,
// The past two input and output samples.
x1: f32,
x2: f32,
y1: f32,
y2: f32,
}
impl Filter {
/// Stage 1 of th BS.1770-4 pre-filter.
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let gain_db = 3.999_843_8;
let q = 0.707_175_25;
let center_hz = 1_681.974_5;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
let vh = 10.0_f32.powf(gain_db / 20.0);
let vb = vh.powf(0.499_666_78);
let a0 = 1.0 + k / q + k * k;
Filter {
b0: (vh + vb * k / q + k * k) / a0,
b1: 2.0 * (k * k - vh) / a0,
b2: (vh - vb * k / q + k * k) / a0,
a1: 2.0 * (k * k - 1.0) / a0,
a2: (1.0 - k / q + k * k) / a0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Stage 2 of th BS.1770-4 pre-filter.
pub fn high_pass(sample_rate_hz: f32) -> Filter {
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
let q = 0.500_327_05;
let center_hz = 38.135_47;
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
Filter {
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
b0: 1.0,
b1: -2.0,
b2: 1.0,
x1: 0.0,
x2: 0.0,
y1: 0.0,
y2: 0.0,
}
}
/// Feed the next input sample, get the next output sample.
#[inline(always)]
pub fn apply(&mut self, x0: f32) -> f32 {
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
- self.a1 * self.y1
- self.a2 * self.y2;
self.x2 = self.x1;
self.x1 = x0;
self.y2 = self.y1;
self.y1 = y0;
y0
}
}
/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
sum: f32,
residue: f32,
}
impl Sum {
#[inline(always)]
fn zero() -> Sum {
Sum {
sum: 0.0,
residue: 0.0,
}
}
#[inline(always)]
fn add(&mut self, x: f32) {
let sum = self.sum + (self.residue + x);
self.residue = (self.residue + x) - (sum - self.sum);
self.sum = sum;
}
}
/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power is equivalent to K-weighted loudness, the only difference
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
/// relative to Full Scale). Loudness units are related to decibels in the
/// following sense: boosting a signal that has a loudness of
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
/// bring the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and high-pass filter that model the
/// effect that humans perceive a certain amount of power in low frequencies to
/// be less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
/// mean square of the samples, if no input samples exceeded the full scale, the
/// power will be in the range [0.0, 1.0]. However, the power delivered by
/// multiple channels, which is a weighted sum over individual channel powers,
/// can exceed this range, because the weighted sum is not normalized.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);
impl Power {
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
///
/// This is the inverse of `loudness_lkfs`.
pub fn from_lkfs(lkfs: f32) -> Power {
// The inverse of the formula below.
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
}
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
///
/// This is the inverse of `from_lkfs`.
pub fn loudness_lkfs(&self) -> f32 {
// Equation 2 (p.5) of BS.1770-4.
-0.691 + 10.0 * self.0.log10()
}
}
/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// These non-overlapping 100ms windows can later be combined into overlapping
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
/// to perform a gated measurement, or they can be combined into even larger
/// windows for a momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
pub inner: T,
}
impl<T> Windows100ms<T> {
/// Wrap a new empty vector.
pub fn new() -> Windows100ms<Vec<T>> {
Windows100ms { inner: Vec::new() }
}
/// Apply `as_ref` to the inner value.
pub fn as_ref(&self) -> Windows100ms<&[Power]>
where
T: AsRef<[Power]>,
{
Windows100ms {
inner: self.inner.as_ref(),
}
}
/// Apply `as_mut` to the inner value.
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
where
T: AsMut<[Power]>,
{
Windows100ms {
inner: self.inner.as_mut(),
}
}
#[allow(clippy::len_without_is_empty)]
/// Apply `len` to the inner value.
pub fn len(&self) -> usize
where
T: AsRef<[Power]>,
{
self.inner.as_ref().len()
}
}
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
/// .unwrap_or(bs1770::Power(0.0))
/// .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
/// The number of samples that fit in 100ms of audio.
samples_per_100ms: u32,
/// Stage 1 filter (head effects, high shelf).
filter_stage1: Filter,
/// Stage 2 filter (high-pass).
filter_stage2: Filter,
/// Sum of the squares over non-overlapping windows of 100ms.
windows: Windows100ms<Vec<Power>>,
/// The number of samples in the current unfinished window.
count: u32,
/// The sum of the squares of the samples in the current unfinished window.
square_sum: Sum,
}
impl ChannelLoudnessMeter {
/// Construct a new loudness meter for the given sample rate.
pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
ChannelLoudnessMeter {
samples_per_100ms: sample_rate_hz / 10,
filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
filter_stage2: Filter::high_pass(sample_rate_hz as f32),
windows: Windows100ms::new(),
count: 0,
square_sum: Sum::zero(),
}
}
/// Feed input samples for loudness analysis.
///
/// # Full scale
///
/// Full scale for the input samples is the interval [-1.0, 1.0]. If your
/// input consists of signed integer samples, you can convert as follows:
///
/// ```ignore
/// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
/// # let bits_per_sample = 16_usize;
/// # let samples = &[0_i16];
/// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
/// // one bit is the sign bit.
/// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
/// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
/// ```
///
/// # Repeated calls
///
/// You can call `push` multiple times to feed multiple batches of samples.
/// This is equivalent to feeding a single chained iterator. The leftover of
/// samples that did not fill a full 100ms window is not discarded:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::ChannelLoudnessMeter;
/// let sample_rate_hz = 44_100;
/// let samples_per_100ms = sample_rate_hz / 10;
/// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
///
/// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
/// assert_eq!(meter.as_100ms_windows().len(), 0);
///
/// meter.push(iter::once(0.0));
/// assert_eq!(meter.as_100ms_windows().len(), 1);
/// ```
pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
let normalizer = 1.0 / self.samples_per_100ms as f32;
// LLVM, if you could go ahead and inline those apply calls, and then
// unroll and vectorize the loop, that'd be terrific.
for x in samples {
let y = self.filter_stage1.apply(x);
let z = self.filter_stage2.apply(y);
self.square_sum.add(z * z);
self.count += 1;
// TODO: Should this branch be marked cold?
if self.count == self.samples_per_100ms {
let mean_squares = Power(self.square_sum.sum * normalizer);
self.windows.inner.push(mean_squares);
// We intentionally do not reset the residue. That way, leftover
// energy from this window is not lost, so for the file overall,
// the sum remains more accurate.
self.square_sum.sum = 0.0;
self.count = 0;
}
}
}
/// Return a reference to the 100ms windows analyzed so far.
pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
self.windows.as_ref()
}
/// Return all 100ms windows analyzed so far.
pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
self.windows
}
}
/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
left: Windows100ms<&[Power]>,
right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
let mut result = Vec::with_capacity(left.len());
for (l, r) in left.inner.iter().zip(right.inner) {
result.push(Power(l.0 + r.0));
}
Windows100ms { inner: result }
}
/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
assert_eq!(
left.len(),
right.len(),
"Channels must have the same length."
);
for (l, r) in left.inner.iter_mut().zip(right.inner) {
l.0 += r.0;
}
}
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurment. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
let absolute_threshold = Power::from_lkfs(-70.0);
// Iterate over all 400ms windows.
for window in windows_100ms.inner.windows(4) {
// Note that the sum over channels has already been performed at this point.
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
if gating_block_power > absolute_threshold {
gating_blocks.push(gating_block_power);
}
}
if gating_blocks.is_empty() {
return None;
}
// Compute the loudness after applying the absolute gate, in order to
// determine the threshold for the relative gate.
let mut sum_power = Sum::zero();
for &gating_block_power in &gating_blocks {
sum_power.add(gating_block_power.0);
}
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
// Stage 2: Apply the relative gate.
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
let mut sum_power = Sum::zero();
let mut n_blocks = 0_usize;
for &gating_block_power in &gating_blocks {
if gating_block_power > relative_threshold {
sum_power.add(gating_block_power.0);
n_blocks += 1;
}
}
if n_blocks == 0 {
return None;
}
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
Some(relative_gated_power)
}

View File

@ -1,9 +1,6 @@
pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};

View File

@ -40,7 +40,7 @@ impl TokenOutputStream {
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();

View File

@ -1,56 +0,0 @@
use std::io::prelude::*;
pub trait Sample {
fn to_i16(&self) -> i16;
}
impl Sample for f32 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for f64 {
fn to_i16(&self) -> i16 {
(self.clamp(-1.0, 1.0) * 32767.0) as i16
}
}
impl Sample for i16 {
fn to_i16(&self) -> i16 {
*self
}
}
pub fn write_pcm_as_wav<W: Write, S: Sample>(
w: &mut W,
samples: &[S],
sample_rate: u32,
) -> std::io::Result<()> {
let len = 12u32; // header
let len = len + 24u32; // fmt
let len = len + samples.len() as u32 * 2 + 8; // data
let n_channels = 1u16;
let bytes_per_second = sample_rate * 2 * n_channels as u32;
w.write_all(b"RIFF")?;
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
w.write_all(b"WAVE")?;
// Format block
w.write_all(b"fmt ")?;
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
w.write_all(&1u16.to_le_bytes())?; // PCM
w.write_all(&n_channels.to_le_bytes())?; // one channel
w.write_all(&sample_rate.to_le_bytes())?;
w.write_all(&bytes_per_second.to_le_bytes())?;
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
w.write_all(&16u16.to_le_bytes())?; // bits per sample
// Data block
w.write_all(b"data")?;
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
for sample in samples.iter() {
w.write_all(&sample.to_i16().to_le_bytes())?
}
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.1"
version = "0.4.0"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.1"
version = "0.4.0"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.1"
version = "0.4.0"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -5,7 +5,6 @@ use serde::Deserialize;
#[serde(rename_all = "lowercase")]
pub enum Activation {
#[default]
#[serde(alias = "gelu")]
Gelu,
#[serde(alias = "gelu_new")]
NewGelu,
@ -20,8 +19,6 @@ pub enum Activation {
HardSwish,
Elu(f64),
LeakyRelu(f64),
#[serde(alias = "gelu_pytorch_tanh")]
GeluPytorchTanh,
}
impl super::Module for Activation {
@ -41,7 +38,6 @@ impl super::Module for Activation {
Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
Self::GeluPytorchTanh => xs.gelu(),
}
}
}

View File

@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig {
pub output_padding: usize,
pub stride: usize,
pub dilation: usize,
pub groups: usize,
// TODO: support groups.
}
impl Default for ConvTranspose1dConfig {
@ -86,7 +86,6 @@ impl Default for ConvTranspose1dConfig {
output_padding: 0,
stride: 1,
dilation: 1,
groups: 1,
}
}
}
@ -110,14 +109,6 @@ impl ConvTranspose1d {
pub fn config(&self) -> &ConvTranspose1dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for ConvTranspose1d {
@ -128,13 +119,12 @@ impl crate::Module for ConvTranspose1d {
self.config.output_padding,
self.config.stride,
self.config.dilation,
self.config.groups,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?;
let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
@ -268,14 +258,6 @@ impl ConvTranspose2d {
pub fn config(&self) -> &ConvTranspose2dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for ConvTranspose2d {
@ -320,22 +302,6 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg))
}
pub fn conv1d_no_bias(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: Conv1dConfig,
vb: crate::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
)?;
Ok(Conv1d::new(ws, None, cfg))
}
pub fn conv_transpose1d(
in_channels: usize,
out_channels: usize,
@ -348,11 +314,7 @@ pub fn conv_transpose1d(
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints(
(in_channels, out_channels / cfg.groups, kernel_size),
"weight",
init,
)?;
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
@ -369,11 +331,7 @@ pub fn conv_transpose1d_no_bias(
lo: -bound,
up: bound,
};
let ws = vb.get_with_hints(
(in_channels, out_channels / cfg.groups, kernel_size),
"weight",
init,
)?;
let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
Ok(ConvTranspose1d::new(ws, None, cfg))
}

View File

@ -19,16 +19,15 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};

View File

@ -57,34 +57,21 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
/// Create or initialize a new linear layer without biases.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None))
}
pub fn linear_b(
in_dim: usize,
out_dim: usize,
bias: bool,
vb: crate::VarBuilder,
) -> Result<Linear> {
if bias {
linear(in_dim, out_dim, vb)
} else {
linear_no_bias(in_dim, out_dim, vb)
}
}

View File

@ -197,7 +197,7 @@ impl RNN for LSTM {
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
Tensor::stack(&states, 1)
Tensor::cat(&states, 1)
}
}

View File

@ -70,7 +70,7 @@ impl VarMap {
///
/// If an error is returned, some of the variables might have already been set to their new
/// values.
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
&mut self,
iter: I,
) -> Result<()> {

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
use candle::test_utils::{to_vec0_round, to_vec2_round};
use anyhow::Result;
use candle::{DType, Device, Tensor, Var};
use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
#[test]
@ -121,40 +121,3 @@ fn adamw_linear_regression() -> Result<()> {
assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
Ok(())
}
#[test]
fn adamw_linear_regression_varmap() -> Result<()> {
use candle_nn::Init::Const;
// Similar as the previous test but using a VarMap.
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
let gen = Linear::new(w_gen, Some(b_gen));
let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
let sample_ys = gen.forward(&sample_xs)?;
let mut var_map = candle_nn::VarMap::new();
let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
let params = ParamsAdamW {
lr: 0.1,
..Default::default()
};
let mut opt = AdamW::new(var_map.all_vars(), params)?;
let lin = Linear::new(w, Some(b));
for _step in 0..100 {
let ys = lin.forward(&sample_xs)?;
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
opt.backward_step(&loss)?;
}
assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);
var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;
assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
Ok(())
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.4.1"
version = "0.4.0"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.1" }
candle-nn = { path = "../candle-nn", version = "0.4.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" }
candle-nn = { path = "../candle-nn", version = "0.4.0" }
prost = "0.12.1"
[build-dependencies]

View File

@ -832,49 +832,7 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "Shape"
#[test]
fn test_shape_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Shape".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec1::<i64>()?;
assert_eq!(results, vec![2, 2]);
Ok(())
}
// #[test]
// "Conv"
// #[test]
@ -883,452 +841,31 @@ fn test_shape_operation() -> Result<()> {
// #[test]
// "Abs"
#[test]
fn test_abs_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Abs".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
Ok(())
}
// #[test]
// "Cos"
#[test]
fn test_cos_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Cos".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
);
Ok(())
}
// #[test]
// "Sin"
#[test]
fn test_sin_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sin".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
Ok(())
}
// #[test]
// "Neg"
#[test]
fn test_neg_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Neg".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]);
Ok(())
}
// #[test]
// "Erf"
// #[test]
// "Tanh"
#[test]
fn test_tanh_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Tanh".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]]
);
Ok(())
}
// #[test]
// "Sigmoid"
#[test]
fn test_sigmoid_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sigmoid".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]]
);
Ok(())
}
// #[test]
// "Gelu"
#[test]
fn test_gelu_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Gelu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]]
);
Ok(())
}
// #[test]
// "Relu"
#[test]
fn test_relu_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Relu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]);
Ok(())
}
// #[test]
// "Constant"
// #[test]

View File

@ -15,7 +15,6 @@ byteorder = { workspace = true }
candle = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true }
fancy-regex = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

View File

@ -1,5 +1,15 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = if bias {
Some(vb.get(size2, "bias")?)
} else {
None
};
Ok(Linear::new(weight, bias))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;

View File

@ -1,4 +1,4 @@
use crate::models::with_tracing::{linear_b as linear, Linear};
use crate::models::with_tracing::Linear;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
@ -51,6 +51,14 @@ impl Config {
}
}
fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
if bias {
crate::models::with_tracing::linear(in_dim, out_dim, vb)
} else {
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
cache: Tensor,

View File

@ -1,460 +0,0 @@
//! EfficientViT (MSRA) inference implementation based on timm.
//!
//! See "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"
//! https://arxiv.org/abs/2305.07027
//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py
use candle::{Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
VarBuilder,
};
#[derive(Clone)]
pub struct Config {
channels: [usize; 3],
blocks: [usize; 3],
heads: [usize; 3],
kernels: [usize; 4],
}
impl Config {
pub fn m0() -> Self {
Self {
channels: [64, 128, 192],
blocks: [1, 2, 3],
heads: [4, 4, 4],
kernels: [5, 5, 5, 5],
}
}
pub fn m1() -> Self {
Self {
channels: [128, 144, 192],
blocks: [1, 2, 3],
heads: [2, 3, 3],
kernels: [7, 5, 3, 3],
}
}
pub fn m2() -> Self {
Self {
channels: [128, 192, 224],
blocks: [1, 2, 3],
heads: [4, 3, 2],
kernels: [7, 5, 3, 3],
}
}
pub fn m3() -> Self {
Self {
channels: [128, 240, 320],
blocks: [1, 2, 3],
heads: [4, 3, 4],
kernels: [5, 5, 5, 5],
}
}
pub fn m4() -> Self {
Self {
channels: [128, 256, 384],
blocks: [1, 2, 3],
heads: [4, 4, 4],
kernels: [7, 5, 3, 3],
}
}
pub fn m5() -> Self {
Self {
channels: [192, 288, 384],
blocks: [1, 3, 4],
heads: [3, 3, 4],
kernels: [7, 5, 3, 3],
}
}
}
fn efficientvit_stemblock(
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride: 2,
padding: 1,
..Default::default()
};
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
let conv = conv2d_no_bias(in_channels, out_channels, 3, conv2d_cfg, vb.pp("conv"))?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&conv)?.apply_t(&bn, false)?;
Ok(xs)
}))
}
fn efficientvit_stem(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv1 = efficientvit_stemblock(3, dim / 8, vb.pp("conv1"))?;
let conv2 = efficientvit_stemblock(dim / 8, dim / 4, vb.pp("conv2"))?;
let conv3 = efficientvit_stemblock(dim / 4, dim / 2, vb.pp("conv3"))?;
let conv4 = efficientvit_stemblock(dim / 2, dim, vb.pp("conv4"))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&conv1)?
.relu()?
.apply(&conv2)?
.relu()?
.apply(&conv3)?
.relu()?
.apply(&conv4)?;
Ok(xs)
}))
}
fn depthwise_conv(
channels: usize,
kernel: usize,
stride: usize,
padding: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride,
padding,
groups: channels,
..Default::default()
};
let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
}
fn pointwise_conv(
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
}
fn conv_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
let pw1 = pointwise_conv(in_channels, out_channels, vb.pp("pw1"))?;
let pw2 = pointwise_conv(out_channels, in_channels, vb.pp("pw2"))?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&pw1)?.relu()?.apply(&pw2)?;
Ok(xs)
}))
}
// Fixed per-stage resolutions
const RESOLUTIONS: [usize; 3] = [14, 7, 4];
// Attention block
fn efficientvit_attn(
cfg: &Config,
stage: usize,
in_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let cga = cascaded_group_attn(cfg, stage, in_channels, vb)?;
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
let (b, c, h, w) = xs.dims4()?;
let win_res = 7; // Fixed window resolution
let pad_b = (win_res - h % win_res) % win_res;
let pad_r = (win_res - w % win_res) % win_res;
let ph = h + pad_b;
let pw = w + pad_r;
let nh = ph / win_res;
let nw = pw / win_res;
if RESOLUTIONS[stage] > win_res {
xs = xs.permute((0, 2, 3, 1))?;
xs = xs.pad_with_zeros(D::Minus1, 0, pad_r)?;
xs = xs.pad_with_zeros(D::Minus2, 0, pad_b)?;
xs = xs
.reshape((b, nh, win_res, nw, win_res, c))?
.transpose(2, 3)?;
xs = xs
.reshape((b * nh * nw, win_res, win_res, c))?
.permute((0, 3, 1, 2))?;
}
xs = xs.apply(&cga)?;
if RESOLUTIONS[stage] > win_res {
xs = xs
.permute((0, 2, 3, 1))?
.reshape((b, nh, nw, win_res, win_res, c))?;
xs = xs.transpose(2, 3)?.reshape((b, ph, pw, c))?;
xs = xs.permute((0, 3, 1, 2))?;
}
Ok(xs)
}))
}
// Cascaded group attention
fn cascaded_group_attn(
cfg: &Config,
stage: usize,
in_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let heads = cfg.heads[stage];
let key_dim = 16;
let val_dim = in_channels / heads;
let scale = (key_dim as f64).powf(-0.5);
let mut dws = Vec::with_capacity(heads);
let mut qkvs = Vec::with_capacity(heads);
for i in 0..heads {
dws.push(depthwise_conv(
key_dim,
cfg.kernels[i],
1,
cfg.kernels[i] / 2,
vb.pp(format!("dws.{i}")),
)?);
qkvs.push(pointwise_conv(
in_channels / heads,
in_channels / heads + 2 * key_dim,
vb.pp(format!("qkvs.{i}")),
)?);
}
let proj = pointwise_conv(in_channels, in_channels, vb.pp("proj.1"))?;
Ok(Func::new(move |xs| {
let (b, _, h, w) = xs.dims4()?;
let feats_in = xs.chunk(heads, 1)?;
let mut feats_out = Vec::with_capacity(heads);
let mut feat = feats_in[0].clone();
for i in 0..heads {
if i > 0 {
feat = (&feat + &feats_in[i])?;
}
feat = feat.apply(&qkvs[i])?;
let res = feat.reshape((b, (), h, w))?;
let q = res.narrow(1, 0, key_dim)?;
let k = res.narrow(1, key_dim, key_dim)?;
let v = res.narrow(1, 2 * key_dim, val_dim)?;
let q = q.apply(&dws[i])?;
let q = q.flatten_from(2)?;
let k = k.flatten_from(2)?;
let v = v.flatten_from(2)?;
let q = (q * scale)?;
let att = q.transpose(D::Minus2, D::Minus1)?.matmul(&k)?;
let att = softmax(&att, D::Minus1)?;
feat = v.matmul(&att.transpose(D::Minus2, D::Minus1)?)?;
feat = feat.reshape((b, val_dim, h, w))?;
feats_out.push(feat.clone());
}
let xs = Tensor::cat(&feats_out, 1)?;
let xs = xs.relu()?.apply(&proj)?;
Ok(xs)
}))
}
// Used by the downsampling layer
fn squeeze_and_excitation(
in_channels: usize,
squeeze_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
residual.broadcast_mul(&xs)
}))
}
// Used by the downsampling layer
fn patchmerge(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
let dim = in_channels;
let hid_dim = in_channels * 4;
let conv1 = pointwise_conv(dim, hid_dim, vb.pp("conv1"))?;
let conv2 = depthwise_conv(hid_dim, 3, 2, 1, vb.pp("conv2"))?;
let conv3 = pointwise_conv(hid_dim, out_channels, vb.pp("conv3"))?;
let se = squeeze_and_excitation(hid_dim, hid_dim / 4, vb.pp("se"))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&conv1)?
.relu()?
.apply(&conv2)?
.relu()?
.apply(&se)?
.apply(&conv3)?;
Ok(xs)
}))
}
// Used by the downsampling layer
fn res(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let dw = depthwise_conv(dim, 3, 1, 1, vb.pp("0.m"))?;
let mlp = conv_mlp(dim, dim * 2, vb.pp("1.m"))?;
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
xs = (&xs + &xs.apply(&dw)?)?;
xs = (&xs + &xs.apply(&mlp)?)?;
Ok(xs)
}))
}
// Downsampling
fn efficientvit_downsample(
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let res1 = res(in_channels, vb.pp("res1"))?;
let res2 = res(out_channels, vb.pp("res2"))?;
let patchmerge = patchmerge(in_channels, out_channels, vb.pp("patchmerge"))?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&res1)?.apply(&patchmerge)?.apply(&res2)?;
Ok(xs)
}))
}
fn efficientvit_block(
cfg: &Config,
stage: usize,
dim: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let dw0 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw0.m"))?;
let dw1 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw1.m"))?;
let ffn0 = conv_mlp(dim, dim * 2, vb.pp("ffn0.m"))?;
let ffn1 = conv_mlp(dim, dim * 2, vb.pp("ffn1.m"))?;
let attn = efficientvit_attn(cfg, stage, dim, vb.pp("mixer.m.attn"))?;
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
xs = (&xs + &xs.apply(&dw0)?)?;
xs = (&xs + &xs.apply(&ffn0)?)?;
xs = (&xs + &xs.apply(&attn)?)?;
xs = (&xs + &xs.apply(&dw1)?)?;
xs = (&xs + &xs.apply(&ffn1)?)?;
Ok(xs)
}))
}
// Each stage is made of blocks. There is a downsampling layer between stages.
fn efficientvit_stage(cfg: &Config, stage: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nblocks = cfg.blocks[stage];
let mut blocks = Vec::with_capacity(nblocks + 1);
let in_channels = if stage > 0 {
cfg.channels[stage - 1]
} else {
cfg.channels[0]
};
let out_channels = cfg.channels[stage];
if stage > 0 {
blocks.push(efficientvit_downsample(
in_channels,
out_channels,
vb.pp("downsample"),
)?);
}
for i in 0..nblocks {
blocks.push(efficientvit_block(
cfg,
stage,
out_channels,
vb.pp(format!("blocks.{i}")),
)?);
}
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
for block in blocks.iter() {
xs = xs.apply(block)?
}
Ok(xs)
}))
}
// Classification head.
fn efficientvit_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
let norm = batch_norm(outputs, 1e-6, vb.pp("bn"))?;
let linear = linear(outputs, nclasses, vb.pp("linear"))?;
Ok(Func::new(move |xs| {
xs.apply_t(&norm, false)?.apply(&linear)
}))
}
// Build a efficientvit model for a given configuration.
fn efficientvit_model(
config: &Config,
nclasses: Option<usize>,
vb: VarBuilder,
) -> Result<Func<'static>> {
let cls = match nclasses {
None => None,
Some(nclasses) => {
let outputs = config.channels[2];
let head = efficientvit_head(outputs, nclasses, vb.pp("head"))?;
Some(head)
}
};
let stem_dim = config.channels[0];
let stem = efficientvit_stem(stem_dim, vb.pp("patch_embed"))?;
let vb = vb.pp("stages");
let stage1 = efficientvit_stage(config, 0, vb.pp(0))?;
let stage2 = efficientvit_stage(config, 1, vb.pp(1))?;
let stage3 = efficientvit_stage(config, 2, vb.pp(2))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.mean(D::Minus2)?
.mean(D::Minus1)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.apply(cls),
}
}))
}
pub fn efficientvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
efficientvit_model(cfg, Some(nclasses), vb)
}
pub fn efficientvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
efficientvit_model(cfg, None, vb)
}

View File

@ -1,773 +0,0 @@
#![allow(unused)]
use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum NormType {
WeightNorm,
TimeGroupNorm,
None,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum PadMode {
Constant,
Reflect,
Replicate,
}
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub target_bandwidths: Vec<f64>,
pub sampling_rate: usize,
pub audio_channels: usize,
pub normalize: bool,
pub chunk_length_s: Option<usize>,
pub overlap: Option<usize>,
pub hidden_size: usize,
pub num_filters: usize,
pub num_residual_layers: usize,
pub upsampling_ratios: Vec<usize>,
pub norm_type: NormType,
pub kernel_size: usize,
pub last_kernel_size: usize,
pub residual_kernel_size: usize,
pub dilation_growth_rate: usize,
pub use_causal_conv: bool,
pub pad_mode: PadMode,
pub compress: usize,
pub num_lstm_layers: usize,
pub trim_right_ratio: f64,
pub codebook_size: usize,
pub codebook_dim: Option<usize>,
pub use_conv_shortcut: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
// This should be PadMode::Reflect which is currently unsupported in candle.
pad_mode: PadMode::Replicate,
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
fn codebook_dim(&self) -> usize {
self.codebook_dim.unwrap_or(self.hidden_size)
}
fn frame_rate(&self) -> usize {
let hop_length: usize = self.upsampling_ratios.iter().product();
(self.sampling_rate + hop_length - 1) / hop_length
}
fn num_quantizers(&self) -> usize {
let num = 1000f64
* self
.target_bandwidths
.last()
.expect("empty target_bandwidths");
(num as usize) / (self.frame_rate() * 10)
}
}
fn get_extra_padding_for_conv1d(
xs: &Tensor,
k_size: usize,
stride: usize,
padding_total: usize,
) -> Result<usize> {
let len = xs.dim(D::Minus1)?;
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
let ideal_len =
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
Ok(ideal_len.saturating_sub(len))
}
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
match mode {
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
}
}
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
config: candle_nn::Conv1dConfig,
vb: VarBuilder,
) -> Result<Conv1d> {
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = vb.get(out_c, "bias")?;
Ok(Conv1d::new(weight, Some(bias), config))
}
fn conv_transpose1d_weight_norm(
in_c: usize,
out_c: usize,
kernel_size: usize,
bias: bool,
config: candle_nn::ConvTranspose1dConfig,
vb: VarBuilder,
) -> Result<ConvTranspose1d> {
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, kernel_size), "weight_v")?;
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
let bias = if bias {
Some(vb.get(out_c, "bias")?)
} else {
None
};
Ok(ConvTranspose1d::new(weight, bias, config))
}
struct CodebookEncode;
impl candle::CustomOp2 for CodebookEncode {
fn name(&self) -> &'static str {
"cb"
}
fn cpu_fwd(
&self,
lhs_storage: &candle::CpuStorage,
lhs_layout: &Layout,
rhs_storage: &candle::CpuStorage,
rhs_layout: &Layout,
) -> Result<(candle::CpuStorage, Shape)> {
use rayon::prelude::*;
let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
if lhs_dim2 != rhs_dim2 {
candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
}
if lhs_dim2 == 0 {
candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
}
let lhs = match lhs_layout.contiguous_offsets() {
None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
Some((o1, o2)) => {
let slice = lhs_storage.as_slice::<f32>()?;
&slice[o1..o2]
}
};
let rhs = match rhs_layout.contiguous_offsets() {
None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
Some((o1, o2)) => {
let slice = rhs_storage.as_slice::<f32>()?;
&slice[o1..o2]
}
};
let dst = (0..lhs_dim1)
.into_par_iter()
.map(|idx1| {
let mut where_min = 0;
let mut min_dist = f32::INFINITY;
let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
for idx2 in 0..rhs_dim1 {
let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
let mut dist = 0f32;
for (a, b) in lhs.iter().zip(rhs.iter()) {
dist += (a - b) * (a - b)
}
if dist < min_dist {
min_dist = dist;
where_min = idx2;
}
}
where_min as u32
})
.collect();
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, (lhs_dim1,).into()))
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Clone, Debug)]
pub struct EuclideanCodebook {
inited: Tensor,
cluster_size: Tensor,
embed: candle_nn::Embedding,
embed_avg: Tensor,
c2: Tensor,
}
impl EuclideanCodebook {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let inited = vb.get(1, "inited")?;
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
let embed = vb.get(e_shape, "embed")?;
let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
let embed_avg = vb.get(e_shape, "embed_avg")?;
Ok(Self {
inited,
cluster_size,
embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
embed_avg,
c2,
})
}
pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
codes.reshape(target_shape)
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut target_shape = xs.dims().to_vec();
target_shape.pop();
let xs = xs.flatten_to(D::Minus2)?;
let _ = xs.dims2()?;
let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
codes.reshape(target_shape)
}
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.embed.forward(embed_ind)?;
Ok(quantize)
}
}
#[derive(Clone, Debug)]
pub struct VectorQuantization {
codebook: EuclideanCodebook,
}
impl VectorQuantization {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let codebook = EuclideanCodebook::new(cfg, vb.pp("codebook"))?;
Ok(Self { codebook })
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.transpose(1, 2)?;
self.codebook.encode_slow(&xs)
}
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
let quantize = self.codebook.decode(embed_ind)?;
let quantize = quantize.transpose(1, 2)?;
Ok(quantize)
}
}
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
layers: Vec<VectorQuantization>,
dtype: DType,
}
impl ResidualVectorQuantizer {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = &vb.pp("layers");
let layers = (0..cfg.num_quantizers())
.map(|i| VectorQuantization::new(cfg, vb.pp(i)))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
layers,
dtype: vb.dtype(),
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let mut codes = Vec::with_capacity(self.layers.len());
let mut residual = xs.clone();
for layer in self.layers.iter() {
let indices = layer.encode(&residual)?;
let quantized = layer.decode(&indices)?;
residual = (residual - quantized)?;
codes.push(indices)
}
Tensor::stack(&codes, 0)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
let ncodes = codes.dim(0)?;
if ncodes > self.layers.len() {
candle::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
)
}
for (i, layer) in self.layers.iter().take(ncodes).enumerate() {
let quantized = layer.decode(&codes.i(i)?)?;
quantized_out = quantized.broadcast_add(&quantized_out)?;
}
Ok(quantized_out)
}
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Clone, Debug)]
pub struct EncodecLSTM {
layers: Vec<candle_nn::LSTM>,
}
impl EncodecLSTM {
pub fn new(dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = &vb.pp("lstm");
let mut layers = vec![];
for layer_idx in 0..cfg.num_lstm_layers {
let config = candle_nn::LSTMConfig {
layer_idx,
..Default::default()
};
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
layers.push(lstm)
}
Ok(Self { layers })
}
}
impl Module for EncodecLSTM {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
// This is different from the Python transformers version as candle LSTM is batch first.
let xs = xs.t()?;
let residual = &xs;
let mut xs = xs.clone();
for layer in self.layers.iter() {
let states = layer.seq(&xs)?;
xs = layer.states_to_tensor(&states)?;
}
let xs = (xs + residual)?.t()?;
Ok(xs)
}
}
#[derive(Clone, Debug)]
pub struct EncodecConvTranspose1d {
conv: ConvTranspose1d,
}
impl EncodecConvTranspose1d {
fn new(
in_c: usize,
out_c: usize,
k: usize,
stride: usize,
_cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let cfg = candle_nn::ConvTranspose1dConfig {
stride,
..Default::default()
};
let conv = conv_transpose1d_weight_norm(in_c, out_c, k, true, cfg, vb.pp("conv"))?;
Ok(Self { conv })
}
}
impl Module for EncodecConvTranspose1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.conv)
}
}
#[derive(Clone, Debug)]
pub struct EncodecConv1d {
causal: bool,
conv: Conv1d,
norm: Option<candle_nn::GroupNorm>,
pad_mode: PadMode,
}
impl EncodecConv1d {
pub fn new(
in_c: usize,
out_c: usize,
kernel_size: usize,
stride: usize,
dilation: usize,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let conv = match cfg.norm_type {
NormType::WeightNorm => conv1d_weight_norm(
in_c,
out_c,
kernel_size,
candle_nn::Conv1dConfig {
stride,
dilation,
..Default::default()
},
vb.pp("conv"),
)?,
NormType::None | NormType::TimeGroupNorm => conv1d(
in_c,
out_c,
kernel_size,
candle_nn::Conv1dConfig {
padding: 0,
stride,
groups: 1,
dilation: 1,
},
vb.pp("conv"),
)?,
};
let norm = match cfg.norm_type {
NormType::None | NormType::WeightNorm => None,
NormType::TimeGroupNorm => {
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
Some(gn)
}
};
Ok(Self {
causal: cfg.use_causal_conv,
conv,
norm,
pad_mode: cfg.pad_mode,
})
}
}
impl Module for EncodecConv1d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _t, _c) = xs.dims3()?;
let k_size = self.conv.weight().dim(D::Minus1)?;
let conv_cfg = self.conv.config();
// Effective kernel size with dilations.
let k_size = (k_size - 1) * conv_cfg.dilation + 1;
let padding_total = k_size - conv_cfg.stride;
let extra_padding =
get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
let xs = if self.causal {
pad1d(xs, padding_total, extra_padding, self.pad_mode)?
} else {
let padding_right = padding_total / 2;
let padding_left = padding_total - padding_right;
pad1d(
xs,
padding_left,
padding_right + extra_padding,
self.pad_mode,
)?
};
let xs = self.conv.forward(&xs)?;
match &self.norm {
None => Ok(xs),
Some(norm) => xs.apply(norm),
}
}
}
#[derive(Clone, Debug)]
pub struct EncodecResnetBlock {
block_conv1: EncodecConv1d,
block_conv2: EncodecConv1d,
shortcut: Option<EncodecConv1d>,
}
impl EncodecResnetBlock {
pub fn new(
dim: usize,
(dilation1, dilation2): (usize, usize),
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
// TODO: Apply dilations!
layer.inc();
let block_conv1 = EncodecConv1d::new(
dim,
h,
cfg.residual_kernel_size,
1,
dilation1,
cfg,
layer.next(),
)?;
layer.inc();
let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
let shortcut = if cfg.use_conv_shortcut {
let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
Some(conv)
} else {
None
};
Ok(Self {
block_conv1,
block_conv2,
shortcut,
})
}
}
impl Module for EncodecResnetBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs.clone();
let xs = xs.elu(1.)?;
let xs = self.block_conv1.forward(&xs)?;
let xs = xs.elu(1.)?;
let xs = self.block_conv2.forward(&xs)?;
let xs = match &self.shortcut {
None => (xs + residual)?,
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
};
Ok(xs)
}
}
struct Layer<'a> {
vb: VarBuilder<'a>,
cnt: usize,
}
impl<'a> Layer<'a> {
fn new(vb: VarBuilder<'a>) -> Self {
Self { vb, cnt: 0 }
}
fn inc(&mut self) {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb
}
}
#[derive(Clone, Debug)]
pub struct Encoder {
init_conv: EncodecConv1d,
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
final_lstm: EncodecLSTM,
final_conv: EncodecConv1d,
}
impl Encoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let init_conv = EncodecConv1d::new(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
1,
cfg,
layer.next(),
)?;
let mut sampling_layers = vec![];
let mut scaling = 1;
for &ratio in cfg.upsampling_ratios.iter().rev() {
let current_scale = scaling * cfg.num_filters;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::new(
current_scale,
(cfg.dilation_growth_rate.pow(j), 1),
cfg,
layer.next(),
)?;
resnets.push(resnet)
}
layer.inc(); // ELU
let conv1d = EncodecConv1d::new(
current_scale,
current_scale * 2,
ratio * 2,
ratio,
1,
cfg,
layer.next(),
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::new(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
1,
cfg,
layer.next(),
)?;
Ok(Self {
init_conv,
sampling_layers,
final_conv,
final_lstm,
})
}
}
impl Module for Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?;
for (resnets, conv) in self.sampling_layers.iter() {
for resnet in resnets.iter() {
xs = xs.apply(resnet)?;
}
xs = xs.elu(1.0)?.apply(conv)?;
}
xs.apply(&self.final_lstm)?
.elu(1.0)?
.apply(&self.final_conv)
}
}
#[derive(Clone, Debug)]
pub struct Decoder {
init_conv: EncodecConv1d,
init_lstm: EncodecLSTM,
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
final_conv: EncodecConv1d,
}
impl Decoder {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::new(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
1,
cfg,
layer.next(),
)?;
let init_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
layer.inc(); // ELU
let conv1d = EncodecConvTranspose1d::new(
current_scale,
current_scale / 2,
ratio * 2,
ratio,
cfg,
layer.next(),
)?;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::new(
current_scale / 2,
(cfg.dilation_growth_rate.pow(j), 1),
cfg,
layer.next(),
)?;
resnets.push(resnet)
}
sampling_layers.push((conv1d, resnets));
scaling /= 2;
}
layer.inc(); // ELU
let final_conv = EncodecConv1d::new(
cfg.num_filters,
cfg.audio_channels,
cfg.last_kernel_size,
1,
1,
cfg,
layer.next(),
)?;
Ok(Self {
init_conv,
init_lstm,
sampling_layers,
final_conv,
})
}
}
impl Module for Decoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
for (conv, resnets) in self.sampling_layers.iter() {
xs = xs.elu(1.)?.apply(conv)?;
for resnet in resnets.iter() {
xs = xs.apply(resnet)?
}
}
xs.elu(1.)?.apply(&self.final_conv)
}
}
#[derive(Debug)]
pub struct Model {
encoder: Encoder,
decoder: Decoder,
quantizer: ResidualVectorQuantizer,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
let quantizer = ResidualVectorQuantizer::new(cfg, vb.pp("quantizer"))?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.encoder.forward(xs)?;
let codes = self.quantizer.encode(&xs)?;
codes.transpose(0, 1)
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let (_b_sz, _codebooks, _seqlen) = codes.dims3()?;
let codes = codes.transpose(0, 1)?;
let embeddings = self.quantizer.decode(&codes)?;
let outputs = self.decoder.forward(&embeddings)?;
Ok(outputs)
}
}

View File

@ -1,8 +1,18 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000;
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = if bias {
Some(vb.get(size2, "bias")?)
} else {
None
};
Ok(Linear::new(weight, bias))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias),

View File

@ -1,381 +0,0 @@
use std::sync::Arc;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};
fn default_max_position_embeddings() -> usize {
4096
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub attention_bias: bool,
pub head_dim: usize,
pub hidden_act: candle_nn::Activation,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub rms_norm_eps: f64,
pub rope_theta: f64,
pub vocab_size: usize,
#[serde(default = "default_max_position_embeddings")]
pub max_position_embeddings: usize,
}
#[derive(Debug, Clone)]
struct RmsNorm {
weight: Tensor,
eps: f64,
}
impl RmsNorm {
fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(dim, "weight")?;
Ok(Self { weight, eps })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(&(&self.weight + 1.0)?)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: candle_nn::Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = cfg.head_dim;
let bias = cfg.attention_bias;
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
rotary_emb,
kv_cache: None,
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, ()))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
device: Device,
dtype: DType,
hidden_size: usize,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
device: vb.device().clone(),
dtype: vb.dtype(),
hidden_size: cfg.hidden_size,
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}

View File

@ -2,6 +2,7 @@ use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
@ -83,9 +84,10 @@ impl Config {
#[derive(Debug, Clone)]
pub struct Cache {
masks: HashMap<usize, Tensor>,
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
kvs: Vec<Option<(Tensor, Tensor)>>,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
device: Device,
@ -110,24 +112,25 @@ impl Cache {
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self {
masks: HashMap::new(),
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: vec![None; config.num_hidden_layers],
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
device: device.clone(),
cos,
sin,
})
}
fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
self.masks.insert(t, mask.clone());
masks.insert(t, mask.clone());
Ok(mask)
}
}
@ -161,6 +164,7 @@ struct CausalSelfAttention {
num_attention_heads: usize,
num_key_value_heads: usize,
head_dim: usize,
cache: Cache,
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
@ -183,11 +187,11 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
@ -197,13 +201,7 @@ impl CausalSelfAttention {
Ok(rope)
}
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?;
let q = self.q_proj.forward(x)?;
@ -220,11 +218,12 @@ impl CausalSelfAttention {
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
@ -240,7 +239,7 @@ impl CausalSelfAttention {
.contiguous()?
}
}
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
@ -259,7 +258,7 @@ impl CausalSelfAttention {
let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@ -284,7 +283,7 @@ impl CausalSelfAttention {
}
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size;
@ -302,6 +301,7 @@ impl CausalSelfAttention {
num_attention_heads: cfg.num_attention_heads,
num_key_value_heads: cfg.num_key_value_heads,
head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
@ -357,25 +357,19 @@ struct Block {
}
impl Block {
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::load(
@ -402,11 +396,11 @@ pub struct Llama {
}
impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?;
@ -414,12 +408,12 @@ impl Llama {
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect();
Ok(Self {

View File

@ -2,6 +2,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Debug, Clone)]
pub struct Config {
@ -69,11 +70,12 @@ impl Config {
}
}
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct Cache {
masks: HashMap<usize, Tensor>,
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
pub kvs: Vec<Option<(Tensor, Tensor)>>,
#[allow(clippy::type_complexity)]
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub cos: Tensor,
pub sin: Tensor,
device: Device,
@ -103,24 +105,25 @@ impl Cache {
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self {
masks: HashMap::new(),
masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache,
kvs: vec![None; cfg.n_layers],
kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
cos,
sin,
device: vb.device().clone(),
})
}
pub fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
pub fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
self.masks.insert(t, mask.clone());
masks.insert(t, mask.clone());
Ok(mask)
}
}
@ -130,7 +133,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
#[derive(Debug, Clone)]
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -139,13 +141,14 @@ struct CausalSelfAttention {
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -159,13 +162,7 @@ impl CausalSelfAttention {
Ok(rope)
}
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@ -175,15 +172,16 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
}
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
@ -194,7 +192,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@ -218,7 +216,7 @@ impl CausalSelfAttention {
}
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -234,6 +232,7 @@ impl CausalSelfAttention {
n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
})
}
}
@ -245,7 +244,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
@ -276,7 +274,6 @@ impl Mlp {
}
}
#[derive(Debug, Clone)]
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
@ -294,23 +291,17 @@ impl Block {
}
}
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
@ -324,7 +315,6 @@ impl Block {
}
}
#[derive(Debug, Clone)]
pub struct Llama {
wte: Embedding,
blocks: Vec<Block>,
@ -334,23 +324,23 @@ pub struct Llama {
}
impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect();
Ok(Self {
wte,

View File

@ -32,9 +32,9 @@ impl Config {
}
pub struct State {
pub hs: Vec<Tensor>,
pub prev_xs: Vec<[Tensor; D_CONV]>,
pub pos: usize,
hs: Vec<Tensor>,
prev_xs: Vec<[Tensor; D_CONV]>,
pos: usize,
}
impl State {

View File

@ -1,968 +0,0 @@
use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
// Equivalent to torch.repeat_interleave
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
img.broadcast_as(dims)?.flatten(dim, dim + 1)
}
pub mod speaker_encoder {
use super::*;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub sampling_rate: usize,
pub partial_n_frames: usize,
pub model_hidden_size: usize,
pub model_embedding_size: usize,
pub model_num_layers: usize,
pub mel_window_length: usize,
pub mel_window_step: usize,
pub mel_n_channels: usize,
}
impl Config {
pub fn cfg() -> Self {
Self {
sampling_rate: 16_000,
partial_n_frames: 160,
model_hidden_size: 256,
model_embedding_size: 256,
model_num_layers: 3,
mel_window_length: 25,
mel_window_step: 10,
mel_n_channels: 40,
}
}
}
pub struct Model {
lstms: Vec<candle_nn::LSTM>,
linear: Linear,
cfg: Config,
}
type Slice = (usize, usize);
impl Model {
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
let mut lstms = Vec::with_capacity(cfg.model_num_layers);
let vb_l = vb.pp("lstm");
for layer_idx in 0..cfg.model_num_layers {
let c = candle_nn::LSTMConfig {
layer_idx,
..Default::default()
};
let in_c = if layer_idx == 0 {
cfg.mel_n_channels
} else {
cfg.model_hidden_size
};
let lstm = candle_nn::lstm(in_c, cfg.model_hidden_size, c, vb_l.clone())?;
lstms.push(lstm)
}
let linear = linear_b(
cfg.model_hidden_size,
cfg.model_embedding_size,
true,
vb.pp("linear"),
)?;
Ok(Self { lstms, linear, cfg })
}
fn compute_partial_slices(
&self,
n_samples: usize,
rate: f64,
min_coverage: f64,
) -> (Vec<Slice>, Vec<Slice>) {
let c = &self.cfg;
// Compute how many frames separate two partial utterances
let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
let n_frames = n_samples / samples_per_frame + 1;
let frame_step =
(c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
// Compute the slices.
let mut wav_slices = vec![];
let mut mel_slices = vec![];
for i in (0..steps).step_by(frame_step) {
let mel_range = (i, i + c.partial_n_frames);
let wav_range = (
i * samples_per_frame,
(i + c.partial_n_frames) * samples_per_frame,
);
mel_slices.push(mel_range);
wav_slices.push(wav_range);
}
// Evaluate whether extra padding is warranted or not.
let last_wav_range = match wav_slices.last() {
None => return (wav_slices, mel_slices),
Some(l) => *l,
};
let coverage = (n_samples - last_wav_range.0) as f64
/ (last_wav_range.1 - last_wav_range.0) as f64;
if coverage > min_coverage && mel_slices.len() > 1 {
mel_slices.pop();
wav_slices.pop();
}
(wav_slices, mel_slices)
}
pub fn embed_utterance(
&self,
wav: &[f32],
mel_filters: &[f32],
rate: f64,
min_c: f64,
device: &Device,
) -> Result<Tensor> {
let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
let max_wave_length = match wav_slices.last() {
Some(v) => v.1,
None => candle::bail!("empty wav slices"),
};
let wav = if max_wave_length > wav.len() {
let mut wav = wav.to_vec();
wav.resize(max_wave_length - wav.len(), 0.0);
std::borrow::Cow::Owned(wav)
} else {
std::borrow::Cow::Borrowed(wav)
};
let mel = crate::models::whisper::audio::log_mel_spectrogram_(
wav.as_ref(),
mel_filters,
/* fft_size */ self.cfg.mel_window_length,
/* fft_step */ self.cfg.mel_window_step,
self.cfg.mel_n_channels,
false,
);
let mels = mel_slices
.iter()
.flat_map(|s| [mel[s.0], mel[s.1]])
.collect::<Vec<_>>();
let mels = Tensor::from_vec(mels, (1, mel_slices.len(), 2), device)?;
let partial_embeds = self.forward(&mels)?;
let raw_embed = partial_embeds.mean(0)?;
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
raw_embed.broadcast_div(&norm)
}
}
impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
use candle_nn::RNN;
// This is different from the Python transformers version as candle LSTM is batch first.
let xs = xs.t()?;
let mut xs = xs.clone();
for layer in self.lstms.iter() {
let states = layer.seq(&xs)?;
xs = layer.states_to_tensor(&states)?;
}
let xs = xs.t()?;
let embeds_raw = xs.apply(&self.linear)?.relu()?;
let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
embeds_raw.broadcast_div(&norm)
}
}
}
type Rank = u32;
pub mod tokenizers {
use super::*;
use std::collections::HashMap;
pub struct BPE {
pub re: fancy_regex::Regex,
pub end_of_text: usize,
pub offset: usize,
pub ranks: HashMap<Vec<u8>, Rank>,
}
impl BPE {
pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
let json = match json.as_object() {
None => candle::bail!("json value is not an object"),
Some(json) => json,
};
let re = match json.get("pat_str") {
None => candle::bail!("json object has no pat_str field"),
Some(pat_str) => match pat_str.as_str() {
None => candle::bail!("pat_str field is not a string"),
Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
},
};
let offset = match json.get("offset") {
None => candle::bail!("json object has no offset field"),
Some(offset) => match offset.as_u64() {
None => candle::bail!("offset field is not a positive int"),
Some(offset) => offset as usize,
},
};
let mut ranks = HashMap::new();
for id in 0u8..=255 {
ranks.insert(vec![id], id as u32);
}
let mergeable_ranks = match json.get("mergeable_ranks") {
None => candle::bail!("json object has no mergeable_ranks field"),
Some(mr) => match mr.as_object() {
None => candle::bail!("mergeable_ranks is not an object"),
Some(mr) => mr,
},
};
for (key, value) in mergeable_ranks.iter() {
let value = match value.as_u64() {
None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
Some(value) => value as u32,
};
if value < 256 {
continue;
}
// No escaping for other keys.
let key = key.as_bytes().to_vec();
ranks.insert(key, value);
}
Ok(Self {
re,
end_of_text,
offset,
ranks,
})
}
// Taken from:
// https://github.com/openai/tiktoken/blob/1b9faf2779855124f05174adf1383e53689ed94b/src/lib.rs#L16C1-L82C2
fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the pair starting at position start.
let mut parts = Vec::with_capacity(piece.len() + 1);
// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
// the way we currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
if rank < min_rank.0 {
min_rank = (rank, i);
}
parts.push((i, rank));
}
parts.push((piece.len() - 1, Rank::MAX));
parts.push((piece.len(), Rank::MAX));
let get_rank = {
#[inline(always)]
|parts: &Vec<(usize, Rank)>, i: usize| {
if (i + 3) < parts.len() {
// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
// parts[i + 1], see comment in the main loop.
*self
.ranks
.get(&piece[parts[i].0..parts[i + 3].0])
.unwrap_or(&Rank::MAX)
} else {
Rank::MAX
}
}
};
// If you have n parts and m merges, this does O(mn) work.
// We could do something with a heap and do O(m log n) work.
// n is often very small so considerations like cache-locality outweigh the algorithmic
// complexity downsides of the `parts` vector.
while min_rank.0 != Rank::MAX {
let i = min_rank.1;
// Update parts[i] and parts[i - 1] before removing parts[i + 1], since
// `parts.remove(i + 1)` will thrash the cache.
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1);
}
parts[i].1 = get_rank(&parts, i);
parts.remove(i + 1);
min_rank = (Rank::MAX, usize::MAX);
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, i);
}
}
}
parts
}
pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
if piece.is_empty() {
return Vec::new();
}
if piece.len() == 1 {
return vec![self.ranks[piece]];
}
assert!(piece.len() > 1);
self._byte_pair_merge(piece)
.windows(2)
.map(|part| self.ranks[&piece[part[0].0..part[1].0]])
.collect()
}
pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
let mut bpe_tokens: Vec<u32> = Vec::new();
for word in self.re.find_iter(text) {
let word = word.map_err(E::wrap)?;
let word_tokens = self.byte_pair_encode(word.as_str().as_bytes());
for &token in word_tokens.iter() {
bpe_tokens.push(token + self.offset as u32)
}
}
bpe_tokens.push((self.end_of_text + self.offset) as u32);
Ok(bpe_tokens)
}
}
}
pub mod gpt {
use super::*;
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum NormType {
LayerNorm,
RMSNorm,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum AttnKernelType {
Fa2,
TorchAttn,
Hand,
}
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum NonLinearityType {
Gelu,
Swiglu,
}
enum Norm {
RMSNorm(candle_nn::RmsNorm),
LayerNorm(candle_nn::LayerNorm),
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L27
#[derive(Debug, Clone)]
pub struct Config {
pub block_size: usize,
pub vocab_sizes: Vec<usize>,
pub target_vocab_sizes: Vec<usize>,
pub n_layer: usize,
pub n_head: usize,
pub n_embd: usize,
pub bias: bool,
pub causal: bool,
pub spk_emb_on_text: bool,
pub norm_type: NormType,
pub rmsnorm_eps: f64,
pub nonlinearity_type: NonLinearityType,
pub swiglu_multiple_of: Option<usize>,
pub attn_kernel_type: AttnKernelType,
pub kv_cache_enabled: bool,
}
impl Config {
pub fn cfg1b_v0_1() -> Self {
Self {
n_layer: 6,
n_head: 6,
n_embd: 384,
block_size: 1024,
bias: false,
vocab_sizes: vec![1538, 1025],
causal: false,
target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
swiglu_multiple_of: Some(256),
norm_type: NormType::LayerNorm,
kv_cache_enabled: false,
attn_kernel_type: AttnKernelType::TorchAttn,
spk_emb_on_text: true,
nonlinearity_type: NonLinearityType::Gelu,
rmsnorm_eps: 1e-5,
}
}
}
impl Norm {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
match cfg.norm_type {
NormType::RMSNorm => {
let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?;
Ok(Self::RMSNorm(rms_norm))
}
NormType::LayerNorm => {
let ln_cfg = candle_nn::LayerNormConfig {
affine: cfg.bias,
..Default::default()
};
let layer_norm = candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?;
Ok(Self::LayerNorm(layer_norm))
}
}
}
}
impl Module for Norm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::RMSNorm(m) => m.forward(xs),
Self::LayerNorm(m) => m.forward(xs),
}
}
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/attn.py#L18
struct SelfAttention {
c_attn: Linear,
c_proj: Linear,
n_head: usize,
}
impl SelfAttention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
// The different attention variants are likely to be identical but still we only accept
// TorchAttn for now.
if cfg.attn_kernel_type != AttnKernelType::TorchAttn {
candle::bail!("only TorchAttn is supported")
}
if cfg.kv_cache_enabled {
candle::bail!("kv_cache_enabled=true is not supported")
}
let c_attn = linear_b(cfg.n_embd, cfg.n_embd * 3, cfg.bias, vb.pp("c_attn"))?;
let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
Ok(Self {
c_attn,
c_proj,
n_head: cfg.n_head,
})
}
}
impl Module for SelfAttention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, t, c) = xs.dims3()?;
let c_x = xs
.apply(&self.c_attn)?
.reshape((b, t, 3, self.n_head, c / self.n_head))?;
let q = c_x.i((.., .., 0))?;
let k = c_x.i((.., .., 1))?;
let v = c_x.i((.., .., 2))?;
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
// TODO: causal mask
let att = candle_nn::ops::softmax_last_dim(&att)?;
let att = att.matmul(&v)?.transpose(1, 2)?;
att.reshape((b, t, c))?.apply(&self.c_proj)
}
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/layers.py#L43
#[allow(clippy::upper_case_acronyms)]
enum MLP {
Gelu {
c_fc: Linear,
c_proj: Linear,
},
Swiglu {
w1: Linear,
w3: Linear,
c_proj: Linear,
},
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_dim = 4 * cfg.n_embd;
let slf = match cfg.nonlinearity_type {
NonLinearityType::Gelu => {
let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
Self::Gelu { c_fc, c_proj }
}
NonLinearityType::Swiglu => {
let hidden_dim = (2 * hidden_dim) / 3;
let swiglu_multiple_of = match cfg.swiglu_multiple_of {
None => candle::bail!("swiglu-multiple-of has to be set"),
Some(smo) => smo,
};
let hidden_dim = swiglu_multiple_of * (hidden_dim + swiglu_multiple_of - 1)
/ swiglu_multiple_of;
let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
Self::Swiglu { w1, w3, c_proj }
}
};
Ok(slf)
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
Self::Swiglu { w1, w3, c_proj } => {
let w1 = xs.apply(w1)?;
let w3 = xs.apply(w3)?;
(w1.silu()? * w3)?.apply(c_proj)
}
}
}
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/combined.py#L7
struct Block {
ln_1: Norm,
ln_2: Norm,
attn: SelfAttention,
mlp: MLP,
}
impl Block {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln_1 = Norm::new(cfg, vb.pp("ln_1"))?;
let ln_2 = Norm::new(cfg, vb.pp("ln_2"))?;
let attn = SelfAttention::new(cfg, vb.pp("attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
Ok(Block {
ln_1,
ln_2,
attn,
mlp,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
Ok(xs)
}
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L79
#[allow(clippy::upper_case_acronyms)]
pub struct Model {
wtes: Vec<candle_nn::Embedding>,
wpe: candle_nn::Embedding,
h: Vec<Block>,
ln_f: Norm,
lm_heads: Vec<Linear>,
cfg: Config,
dtype: DType,
}
impl Model {
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
let vb_t = vb.pp("transformer");
let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;
let mut wtes = Vec::with_capacity(cfg.vocab_sizes.len());
let vb_w = vb_t.pp("wtes");
for (idx, vocab_size) in cfg.vocab_sizes.iter().enumerate() {
let wte = candle_nn::embedding(*vocab_size, cfg.n_embd, vb_w.pp(idx))?;
wtes.push(wte)
}
let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;
let mut h = Vec::with_capacity(cfg.n_layer);
let vb_h = vb_t.pp("h");
for idx in 0..cfg.n_layer {
let block = Block::new(&cfg, vb_h.pp(idx))?;
h.push(block)
}
let mut lm_heads = Vec::with_capacity(cfg.target_vocab_sizes.len());
let vb_l = vb.pp("lm_heads");
for (idx, vocab_size) in cfg.target_vocab_sizes.iter().enumerate() {
let head = linear_b(cfg.n_embd, *vocab_size, false, vb_l.pp(idx))?;
lm_heads.push(head)
}
Ok(Self {
wtes,
wpe,
h,
ln_f,
lm_heads,
cfg,
dtype: vb.dtype(),
})
}
pub fn config(&self) -> &Config {
&self.cfg
}
pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
let device = idx.device();
let (b, _num_hierarchies, t) = idx.dims3()?;
let pos = Tensor::arange(0u32, t as u32, device)?;
let pos_emb = pos.apply(&self.wpe)?;
let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
for (wte_idx, wte) in self.wtes.iter().enumerate() {
let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
tok_emb = (tok_emb + emb)?;
}
// TODO: speaker embs.
let spk_emb = 0f64;
let mut xs = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
for block in self.h.iter() {
xs = xs.apply(block)?
}
let xs = xs.apply(&self.ln_f)?;
let mut logits = Vec::with_capacity(self.lm_heads.len());
for lm_head in self.lm_heads.iter() {
// non-causal mode only.
let ys = xs.apply(lm_head)?;
logits.push(ys)
}
Ok(logits)
}
}
}
pub mod transformer {
use super::*;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub block_size: usize,
pub vocab_size: usize,
pub n_layer: usize,
pub n_head: usize,
pub dim: usize,
pub speaker_emb_dim: usize,
pub intermediate_size: Option<usize>,
pub n_local_heads: Option<usize>,
pub norm_eps: f64,
}
impl Config {
pub fn cfg1b_v0_1() -> Self {
Self {
n_layer: 24,
n_head: 16,
dim: 2048,
vocab_size: 2562,
speaker_emb_dim: 256,
block_size: 2048,
intermediate_size: None,
n_local_heads: None,
norm_eps: 1e-5,
}
}
fn n_local_heads(&self) -> usize {
self.n_local_heads.unwrap_or(self.n_head)
}
fn head_dim(&self) -> usize {
self.dim / self.n_head
}
fn intermediate_size(&self) -> usize {
match self.intermediate_size {
Some(intermediate_size) => intermediate_size,
None => {
let hidden_dim = self.dim * 4;
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
(n_hidden + 255) / 256 * 256
}
}
}
}
#[derive(Debug, Clone)]
struct FeedForward {
w1: Linear,
w2: Linear,
w3: Linear,
}
impl FeedForward {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let i_size = cfg.intermediate_size();
let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
Ok(Self { w1, w2, w3 })
}
}
impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
swiglu.apply(&self.w2)
}
}
#[derive(Debug, Clone)]
struct Attention {
wqkv: Linear,
wo: Linear,
dim: usize,
kv_size: usize,
n_local_heads: usize,
head_dim: usize,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let n_local_heads = cfg.n_local_heads();
let head_dim = cfg.head_dim();
let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
Ok(Self {
wqkv,
wo,
dim: cfg.dim,
kv_size: n_local_heads * head_dim,
n_local_heads,
head_dim,
n_head: cfg.n_head,
kv_cache: None,
})
}
fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
let (b_sz, seqlen, _) = xs.dims3()?;
let qkv = xs.apply(&self.wqkv)?;
let q = qkv.narrow(D::Minus1, 0, self.dim)?;
let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
let q = q
.reshape((b_sz, seqlen, self.n_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &k], 2)?;
let v = Tensor::cat(&[prev_v, &v], 2)?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
let attn_weights = attn_weights.broadcast_add(mask)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, seqlen, self.dim))?
.apply(&self.wo)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct Block {
attention: Attention,
feed_forward: FeedForward,
ffn_norm: RmsNorm,
attention_norm: RmsNorm,
}
impl Block {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = Attention::new(cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
Ok(Self {
attention,
feed_forward,
ffn_norm,
attention_norm,
})
}
fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
let hs = xs.apply(&self.attention_norm)?;
let hs = (xs + self.attention.forward(&hs, pos, mask))?;
&hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
}
fn clear_kv_cache(&mut self) {
self.attention.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
tok_embeddings: Embedding,
pos_embeddings: Embedding,
speaker_cond_pos: Linear,
layers: Vec<Block>,
norm: RmsNorm,
output: Linear,
spk_cond_mask: Tensor,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
let speaker_cond_pos = linear_b(
cfg.speaker_emb_dim,
cfg.dim,
false,
vb.pp("speaker_cond_pos"),
)?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
let dtype = vb.dtype();
let spk_cond_mask = Tensor::cat(
&[
Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
],
0,
)?;
Ok(Self {
tok_embeddings,
pos_embeddings,
speaker_cond_pos,
layers,
norm,
output,
spk_cond_mask,
})
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
let (_b_sz, seqlen) = xs.dims2()?;
let mask: Vec<_> = (0..seqlen)
.flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
let tok_embeddings = xs.apply(&self.tok_embeddings)?;
let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
let mut xs = tok_embeddings
.broadcast_add(&pos_embeddings)?
.broadcast_add(
&spk_emb
.apply(&self.speaker_cond_pos)?
.broadcast_mul(&self.spk_cond_mask)?,
)?;
let mask = mask.to_dtype(xs.dtype())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, pos, &mask)?
}
xs.narrow(1, seqlen - 1, 1)?
.apply(&self.norm)?
.apply(&self.output)
}
}
}
pub mod adapters {
// https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
pub struct TiltedEncodec {
end_of_audio_token: u32,
}
impl TiltedEncodec {
pub fn new(end_of_audio_token: u32) -> Self {
Self { end_of_audio_token }
}
pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
let mut text_ids = vec![];
let mut extracted_audio_ids = vec![];
let mut min_audio_ids_len = usize::MAX;
for (book_id, tokens) in tokens.iter().enumerate() {
let mut audio_ids = vec![];
for &t in tokens.iter() {
#[allow(clippy::comparison_chain)]
if t > self.end_of_audio_token {
if book_id == 0 {
text_ids.push(t)
}
} else if t < self.end_of_audio_token {
audio_ids.push(t)
}
}
min_audio_ids_len = usize::min(min_audio_ids_len, audio_ids.len());
extracted_audio_ids.push(audio_ids)
}
for audio_ids in extracted_audio_ids.iter_mut() {
audio_ids.truncate(min_audio_ids_len)
}
(text_ids, extracted_audio_ids)
}
}
// https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
pub struct FlattenedInterleavedEncodec2Codebook {
end_of_audio_token: u32,
}
impl FlattenedInterleavedEncodec2Codebook {
pub fn new(end_of_audio_token: u32) -> Self {
Self { end_of_audio_token }
}
pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
let mut text_ids = vec![];
let mut audio_ids1 = vec![];
let mut audio_ids2 = vec![];
for &t in tokens.iter() {
#[allow(clippy::comparison_chain)]
if t < self.end_of_audio_token {
audio_ids1.push(t)
} else if t < 2 * self.end_of_audio_token {
audio_ids2.push(t - self.end_of_audio_token)
} else {
text_ids.push(t)
}
}
(text_ids, audio_ids1, audio_ids2)
}
}
}

View File

@ -8,17 +8,13 @@ pub mod convnext;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
pub mod efficientvit;
pub mod encodec;
pub mod falcon;
pub mod gemma;
pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod mamba;
pub mod marian;
pub mod metavoice;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;
@ -33,24 +29,20 @@ pub mod quantized_llama2_c;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod starcoder2;
pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod vocos;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;

View File

@ -157,16 +157,16 @@ struct LayerWeights {
head_dim: usize,
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
span_mlp: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
@ -240,7 +240,7 @@ impl LayerWeights {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, &self.neg_inf)?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
@ -298,7 +298,6 @@ impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
@ -338,7 +337,6 @@ impl ModelWeights {
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,
@ -387,7 +385,6 @@ impl ModelWeights {
.and_then(|m| m.to_f32())
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
@ -458,7 +455,6 @@ impl ModelWeights {
head_dim: embedding_length / head_count,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,

View File

@ -7,7 +7,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
#[derive(Debug, Clone)]
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -16,13 +15,14 @@ struct CausalSelfAttention {
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -36,13 +36,7 @@ impl CausalSelfAttention {
Ok(rope)
}
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
@ -52,15 +46,16 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
}
cache.kvs[block_idx] = Some((k.clone(), v.clone()))
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
@ -71,7 +66,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
@ -95,7 +90,7 @@ impl CausalSelfAttention {
}
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -111,6 +106,7 @@ impl CausalSelfAttention {
n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
})
}
}
@ -122,7 +118,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m)
}
#[derive(Debug, Clone)]
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
@ -153,7 +148,6 @@ impl Mlp {
}
}
#[derive(Debug, Clone)]
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
@ -171,23 +165,17 @@ impl Block {
}
}
fn forward(
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
@ -201,7 +189,6 @@ impl Block {
}
}
#[derive(Debug, Clone)]
pub struct QLlama {
wte: Embedding,
blocks: Vec<Block>,
@ -211,23 +198,23 @@ pub struct QLlama {
}
impl QLlama {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect();
Ok(Self {
wte,

View File

@ -1,286 +0,0 @@
use crate::{
quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
quantized_var_builder::VarBuilder,
};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{GroupNorm, LayerNorm, Module};
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
#[derive(Debug, Clone)]
struct SelfAttention {
key: Linear,
receptance: Linear,
value: Linear,
gate: Linear,
output: Linear,
ln_x: candle_nn::GroupNorm,
time_mix_key: Tensor,
time_mix_value: Tensor,
time_mix_receptance: Tensor,
time_decay: Tensor,
time_faaaa: Tensor,
time_mix_gate: Tensor,
layer_id: usize,
n_attn_heads: usize,
}
impl SelfAttention {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let attn_hidden_size = cfg.attention_hidden_size;
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
let vb_x = vb.pp("ln_x");
let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
let ln_x = GroupNorm::new(
ln_x_weight,
ln_x_bias,
hidden_size,
hidden_size / cfg.head_size,
1e-5,
)?;
let time_mix_key = vb
.get((1, 1, cfg.hidden_size), "time_mix_key")?
.dequantize(vb.device())?;
let time_mix_value = vb
.get((1, 1, cfg.hidden_size), "time_mix_value")?
.dequantize(vb.device())?;
let time_mix_receptance = vb
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
.dequantize(vb.device())?;
let n_attn_heads = cfg.hidden_size / cfg.head_size;
let time_decay = vb
.get((n_attn_heads, cfg.head_size), "time_decay")?
.dequantize(vb.device())?;
let time_faaaa = vb
.get((n_attn_heads, cfg.head_size), "time_faaaa")?
.dequantize(vb.device())?;
let time_mix_gate = vb
.get((1, 1, cfg.hidden_size), "time_mix_gate")?
.dequantize(vb.device())?;
Ok(Self {
key,
value,
receptance,
gate,
output,
ln_x,
time_mix_key,
time_mix_value,
time_mix_receptance,
time_decay,
time_faaaa,
time_mix_gate,
layer_id,
n_attn_heads,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let h = self.time_decay.dim(0)?;
let (b, t, s) = xs.dims3()?;
let s = s / h;
let (receptance, key, value, gate) = {
// extract key-value
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
let shifted = if shifted.rank() == 2 {
shifted.unsqueeze(1)?
} else {
shifted
};
let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
let receptance = ((xs * &self.time_mix_receptance)?
+ &shifted * (1.0 - &self.time_mix_receptance)?)?;
let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
let key = self.key.forward(&key)?;
let value = self.value.forward(&value)?;
let receptance = self.receptance.forward(&receptance)?;
let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
(receptance, key, value, gate)
};
// linear attention
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
let time_decay = self
.time_decay
.exp()?
.neg()?
.exp()?
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let time_faaaa =
self.time_faaaa
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
state_ = (&at + time_decay.broadcast_mul(&state_))?;
out.push(out_)
}
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
let out = (out * gate)?.apply(&self.output)?;
state.per_layer[self.layer_id].linear_attention = state_;
Ok(out)
}
}
#[derive(Debug, Clone)]
struct FeedForward {
time_mix_key: Tensor,
time_mix_receptance: Tensor,
key: Linear,
receptance: Linear,
value: Linear,
layer_id: usize,
}
impl FeedForward {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let int_size = cfg
.intermediate_size
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
let time_mix_key = vb
.get((1, 1, cfg.hidden_size), "time_mix_key")?
.dequantize(vb.device())?;
let time_mix_receptance = vb
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
.dequantize(vb.device())?;
Ok(Self {
key,
receptance,
value,
time_mix_key,
time_mix_receptance,
layer_id,
})
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let shifted = &state.per_layer[self.layer_id].feed_forward;
let key = (xs.broadcast_mul(&self.time_mix_key)?
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
let key = key.apply(&self.key)?.relu()?.sqr()?;
let value = key.apply(&self.value)?;
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
let xs = (receptance * value)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct Block {
pre_ln: Option<LayerNorm>,
ln1: LayerNorm,
ln2: LayerNorm,
attention: SelfAttention,
feed_forward: FeedForward,
}
impl Block {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
let pre_ln = if layer_id == 0 {
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
Some(ln)
} else {
None
};
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
Ok(Self {
pre_ln,
ln1,
ln2,
attention,
feed_forward,
})
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let xs = match self.pre_ln.as_ref() {
None => xs.clone(),
Some(pre_ln) => xs.apply(pre_ln)?,
};
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
let xs = (xs + attention)?;
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
let xs = (xs + feed_forward)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct Model {
embeddings: Embedding,
blocks: Vec<Block>,
ln_out: LayerNorm,
head: Linear,
rescale_every: usize,
layers_are_rescaled: bool,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("rwkv");
let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
let vb_b = vb_m.pp("blocks");
for block_index in 0..cfg.num_hidden_layers {
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
blocks.push(block)
}
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
Ok(Self {
embeddings,
blocks,
ln_out,
head,
rescale_every: cfg.rescale_every,
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let (_b_size, _seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embeddings)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
xs = block.forward(&xs, state)?;
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
xs = (xs / 2.)?
}
}
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
state.pos += 1;
Ok(xs)
}
}

View File

@ -1,332 +0,0 @@
use crate::{
quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
quantized_var_builder::VarBuilder,
};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{GroupNorm, LayerNorm, Module};
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
#[derive(Debug, Clone)]
struct SelfAttention {
key: Linear,
receptance: Linear,
value: Linear,
gate: Linear,
output: Linear,
ln_x: candle_nn::GroupNorm,
time_mix_x: Tensor,
time_mix_w: Tensor,
time_mix_key: Tensor,
time_mix_value: Tensor,
time_mix_receptance: Tensor,
time_decay: Tensor,
time_faaaa: Tensor,
time_mix_gate: Tensor,
time_decay_w1: Tensor,
time_decay_w2: Tensor,
time_mix_w1: Tensor,
time_mix_w2: Tensor,
layer_id: usize,
n_attn_heads: usize,
}
impl SelfAttention {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let attn_hidden_size = cfg.attention_hidden_size;
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
let vb_x = vb.pp("ln_x");
let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
let ln_x = GroupNorm::new(
ln_x_weight,
ln_x_bias,
hidden_size,
hidden_size / cfg.head_size,
1e-5,
)?;
let time_mix_x = vb
.get((1, 1, cfg.hidden_size), "time_mix_x")?
.dequantize(vb.device())?;
let time_mix_w = vb
.get((1, 1, cfg.hidden_size), "time_mix_w")?
.dequantize(vb.device())?;
let time_mix_key = vb
.get((1, 1, cfg.hidden_size), "time_mix_key")?
.dequantize(vb.device())?;
let time_mix_value = vb
.get((1, 1, cfg.hidden_size), "time_mix_value")?
.dequantize(vb.device())?;
let time_mix_receptance = vb
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
.dequantize(vb.device())?;
let n_attn_heads = cfg.hidden_size / cfg.head_size;
let time_decay = vb
.get((1, 1, cfg.hidden_size), "time_decay")?
.dequantize(vb.device())?;
let time_faaaa = vb
.get((n_attn_heads, cfg.head_size), "time_faaaa")?
.dequantize(vb.device())?;
let time_mix_gate = vb
.get((1, 1, cfg.hidden_size), "time_mix_gate")?
.dequantize(vb.device())?;
let time_decay_w1 = vb
.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
.dequantize(vb.device())?;
let time_decay_w2 = vb
.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
.dequantize(vb.device())?;
let time_mix_w1 = vb
.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
.dequantize(vb.device())?;
let time_mix_w2 = vb
.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
.dequantize(vb.device())?;
Ok(Self {
key,
value,
receptance,
gate,
output,
ln_x,
time_mix_x,
time_mix_w,
time_mix_key,
time_mix_value,
time_mix_receptance,
time_decay,
time_faaaa,
time_mix_gate,
time_decay_w1,
time_decay_w2,
time_mix_w1,
time_mix_w2,
layer_id,
n_attn_heads,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let h = self.n_attn_heads;
let (b, t, s) = xs.dims3()?;
let s = s / h;
let (receptance, key, value, gate, w) = {
// extract key-value
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
let shifted = if shifted.rank() == 2 {
shifted.unsqueeze(1)?
} else {
shifted
};
let sx = (&shifted - xs)?;
let xxx = (xs + &sx * &self.time_mix_x)?;
let xxx = xxx
.broadcast_matmul(&self.time_mix_w1)?
.tanh()?
.reshape((b * t, 5, ()))?
.transpose(0, 1)?;
let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
let w = (&self.time_decay
+ xw.broadcast_matmul(&self.time_decay_w1)?
.tanh()?
.broadcast_matmul(&self.time_decay_w2)?)?
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let key = self.key.forward(&xk)?;
let value = self.value.forward(&xv)?;
let receptance = self.receptance.forward(&xr)?;
let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
(receptance, key, value, gate, w)
};
// linear attention
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
let w = w.exp()?.neg()?.exp()?;
let time_faaaa =
self.time_faaaa
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
state_ = (&at + w.broadcast_mul(&state_))?;
out.push(out_)
}
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
let out = (out * gate)?.apply(&self.output)?;
state.per_layer[self.layer_id].linear_attention = state_;
Ok(out)
}
}
#[derive(Debug, Clone)]
struct FeedForward {
time_mix_key: Tensor,
time_mix_receptance: Tensor,
key: Linear,
receptance: Linear,
value: Linear,
layer_id: usize,
}
impl FeedForward {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let int_size = cfg
.intermediate_size
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
let time_mix_key = vb
.get((1, 1, cfg.hidden_size), "time_mix_key")?
.dequantize(vb.device())?;
let time_mix_receptance = vb
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
.dequantize(vb.device())?;
Ok(Self {
key,
receptance,
value,
time_mix_key,
time_mix_receptance,
layer_id,
})
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let shifted = state.per_layer[self.layer_id]
.feed_forward
.broadcast_sub(xs)?;
let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
let key = key.apply(&self.key)?.relu()?.sqr()?;
let value = key.apply(&self.value)?;
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
let xs = (receptance * value)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct Block {
pre_ln: Option<LayerNorm>,
ln1: LayerNorm,
ln2: LayerNorm,
attention: SelfAttention,
feed_forward: FeedForward,
}
impl Block {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
let pre_ln = if layer_id == 0 {
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
Some(ln)
} else {
None
};
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
Ok(Self {
pre_ln,
ln1,
ln2,
attention,
feed_forward,
})
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let xs = match self.pre_ln.as_ref() {
None => xs.clone(),
Some(pre_ln) => xs.apply(pre_ln)?,
};
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
let xs = (xs + attention)?;
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
let xs = (xs + feed_forward)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct Model {
embeddings: Embedding,
blocks: Vec<Block>,
ln_out: LayerNorm,
head: Linear,
rescale_every: usize,
layers_are_rescaled: bool,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("rwkv");
let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
let vb_b = vb_m.pp("blocks");
for block_index in 0..cfg.num_hidden_layers {
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
blocks.push(block)
}
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
Ok(Self {
embeddings,
blocks,
ln_out,
head,
rescale_every: cfg.rescale_every,
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let (_b_size, _seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embeddings)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
xs = block.forward(&xs, state)?;
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
xs = (xs / 2.)?
}
}
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
state.pos += 1;
Ok(xs)
}
}

View File

@ -186,14 +186,10 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
vb.pp("input_layernorm"),
)?;
let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
cfg.norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
@ -244,7 +240,7 @@ impl Model {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,

View File

@ -22,15 +22,15 @@ pub struct Config {
pub rescale_every: usize,
}
pub struct StatePerLayer {
pub extract_key_value: Tensor,
pub linear_attention: Tensor,
pub feed_forward: Tensor,
struct StatePerLayer {
extract_key_value: Tensor,
linear_attention: Tensor,
feed_forward: Tensor,
}
pub struct State {
pub per_layer: Vec<StatePerLayer>,
pub pos: usize,
per_layer: Vec<StatePerLayer>,
pos: usize,
}
impl State {
@ -124,7 +124,7 @@ impl SelfAttention {
let (b, t, s) = xs.dims3()?;
let s = s / h;
let (receptance, key, value, gate) = {
// extract key-value
// exctract key-value
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
let shifted = if shifted.rank() == 2 {
shifted.unsqueeze(1)?
@ -164,9 +164,10 @@ impl SelfAttention {
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
//
let rt = receptance.i((.., .., t_..t_ + 1))?;
let kt = key.i((.., .., .., t_..t_ + 1))?;
let vt = value.i((.., .., t_..t_ + 1))?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;

View File

@ -1,295 +0,0 @@
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
#[derive(Debug, Clone)]
struct SelfAttention {
key: Linear,
receptance: Linear,
value: Linear,
gate: Linear,
output: Linear,
ln_x: candle_nn::GroupNorm,
time_mix_x: Tensor,
time_mix_w: Tensor,
time_mix_key: Tensor,
time_mix_value: Tensor,
time_mix_receptance: Tensor,
time_decay: Tensor,
time_faaaa: Tensor,
time_mix_gate: Tensor,
time_decay_w1: Tensor,
time_decay_w2: Tensor,
time_mix_w1: Tensor,
time_mix_w2: Tensor,
layer_id: usize,
n_attn_heads: usize,
}
impl SelfAttention {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let attn_hidden_size = cfg.attention_hidden_size;
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
let ln_x = candle_nn::group_norm(
hidden_size / cfg.head_size,
hidden_size,
1e-5,
vb.pp("ln_x"),
)?;
let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
let n_attn_heads = cfg.hidden_size / cfg.head_size;
let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
Ok(Self {
key,
value,
receptance,
gate,
output,
ln_x,
time_mix_x,
time_mix_w,
time_mix_key,
time_mix_value,
time_mix_receptance,
time_decay,
time_faaaa,
time_mix_gate,
time_decay_w1,
time_decay_w2,
time_mix_w1,
time_mix_w2,
layer_id,
n_attn_heads,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let h = self.n_attn_heads;
let (b, t, s) = xs.dims3()?;
let s = s / h;
let (receptance, key, value, gate, w) = {
// extract key-value
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
let shifted = if shifted.rank() == 2 {
shifted.unsqueeze(1)?
} else {
shifted
};
let sx = (&shifted - xs)?;
let xxx = (xs + &sx * &self.time_mix_x)?;
let xxx = xxx
.broadcast_matmul(&self.time_mix_w1)?
.tanh()?
.reshape((b * t, 5, ()))?
.transpose(0, 1)?;
let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;
let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);
let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;
let w = (&self.time_decay
+ xw.broadcast_matmul(&self.time_decay_w1)?
.tanh()?
.broadcast_matmul(&self.time_decay_w2)?)?
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let key = self.key.forward(&xk)?;
let value = self.value.forward(&xv)?;
let receptance = self.receptance.forward(&xr)?;
let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
(receptance, key, value, gate, w)
};
// linear attention
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
let w = w.exp()?.neg()?.exp()?;
let time_faaaa =
self.time_faaaa
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
state_ = (&at + w.broadcast_mul(&state_))?;
out.push(out_)
}
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
let out = (out * gate)?.apply(&self.output)?;
state.per_layer[self.layer_id].linear_attention = state_;
Ok(out)
}
}
#[derive(Debug, Clone)]
struct FeedForward {
time_mix_key: Tensor,
time_mix_receptance: Tensor,
key: Linear,
receptance: Linear,
value: Linear,
layer_id: usize,
}
impl FeedForward {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let int_size = cfg
.intermediate_size
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
Ok(Self {
key,
receptance,
value,
time_mix_key,
time_mix_receptance,
layer_id,
})
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let shifted = state.per_layer[self.layer_id]
.feed_forward
.broadcast_sub(xs)?;
let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
let key = key.apply(&self.key)?.relu()?.sqr()?;
let value = key.apply(&self.value)?;
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
let xs = (receptance * value)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct Block {
pre_ln: Option<LayerNorm>,
ln1: LayerNorm,
ln2: LayerNorm,
attention: SelfAttention,
feed_forward: FeedForward,
}
impl Block {
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
let pre_ln = if layer_id == 0 {
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
Some(ln)
} else {
None
};
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
Ok(Self {
pre_ln,
ln1,
ln2,
attention,
feed_forward,
})
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let xs = match self.pre_ln.as_ref() {
None => xs.clone(),
Some(pre_ln) => xs.apply(pre_ln)?,
};
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
let xs = (xs + attention)?;
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
let xs = (xs + feed_forward)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct Model {
embeddings: Embedding,
blocks: Vec<Block>,
ln_out: LayerNorm,
head: Linear,
rescale_every: usize,
layers_are_rescaled: bool,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("rwkv");
let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
let vb_b = vb_m.pp("blocks");
for block_index in 0..cfg.num_hidden_layers {
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
blocks.push(block)
}
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
Ok(Self {
embeddings,
blocks,
ln_out,
head,
rescale_every: cfg.rescale_every,
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let (_b_size, _seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embeddings)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
xs = block.forward(&xs, state)?;
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
xs = (xs / 2.)?
}
}
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
state.pos += 1;
Ok(xs)
}
}

View File

@ -1,705 +0,0 @@
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
#[serde(default)]
pub id2label: HashMap<String, String>,
pub num_channels: usize,
pub num_encoder_blocks: usize,
pub depths: Vec<usize>,
pub sr_ratios: Vec<usize>,
pub hidden_sizes: Vec<usize>,
pub patch_sizes: Vec<usize>,
pub strides: Vec<usize>,
pub num_attention_heads: Vec<usize>,
pub mlp_ratios: Vec<usize>,
pub hidden_act: candle_nn::Activation,
pub layer_norm_eps: f64,
pub decoder_hidden_size: usize,
}
#[derive(Debug, Clone)]
struct SegformerOverlapPatchEmbeddings {
projection: Conv2d,
layer_norm: candle_nn::LayerNorm,
}
impl SegformerOverlapPatchEmbeddings {
fn new(
config: &Config,
patch_size: usize,
stride: usize,
num_channels: usize,
hidden_size: usize,
vb: VarBuilder,
) -> Result<Self> {
let projection = conv2d(
num_channels,
hidden_size,
patch_size,
Conv2dConfig {
stride,
padding: patch_size / 2,
..Default::default()
},
vb.pp("proj"),
)?;
let layer_norm =
candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
Ok(Self {
projection,
layer_norm,
})
}
}
impl Module for SegformerOverlapPatchEmbeddings {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let embeddings = self.projection.forward(x)?;
let shape = embeddings.shape();
// [B, C, H, W] -> [B, H * W, C]
let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
let embeddings = self.layer_norm.forward(&embeddings)?;
// [B, H * W, C] -> [B, C, H, W]
let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
Ok(embeddings)
}
}
#[derive(Debug, Clone)]
struct SegformerEfficientSelfAttention {
num_attention_heads: usize,
attention_head_size: usize,
query: Linear,
key: Linear,
value: Linear,
sr: Option<Conv2d>,
layer_norm: Option<layer_norm::LayerNorm>,
}
impl SegformerEfficientSelfAttention {
fn new(
config: &Config,
hidden_size: usize,
num_attention_heads: usize,
sequence_reduction_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
if hidden_size % num_attention_heads != 0 {
candle::bail!(
"The hidden size {} is not a multiple of the number of attention heads {}",
hidden_size,
num_attention_heads
)
}
let attention_head_size = hidden_size / num_attention_heads;
let all_head_size = num_attention_heads * attention_head_size;
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
(
Some(conv2d(
hidden_size,
hidden_size,
sequence_reduction_ratio,
Conv2dConfig {
stride: sequence_reduction_ratio,
..Default::default()
},
vb.pp("sr"),
)?),
Some(candle_nn::layer_norm(
hidden_size,
config.layer_norm_eps,
vb.pp("layer_norm"),
)?),
)
} else {
(None, None)
};
Ok(Self {
num_attention_heads,
attention_head_size,
query,
key,
value,
sr,
layer_norm,
})
}
fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
let (batch, seq_length, _) = hidden_states.shape().dims3()?;
let new_shape = &[
batch,
seq_length,
self.num_attention_heads,
self.attention_head_size,
];
let hidden_states = hidden_states.reshape(new_shape)?;
let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
Ok(hidden_states)
}
}
impl Module for SegformerEfficientSelfAttention {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// [B, C, H, W] -> [B, H * W, C]
let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
let query = self
.transpose_for_scores(self.query.forward(&hidden_states)?)?
.contiguous()?;
let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
let hidden_states = sr.forward(x)?;
// [B, C, H, W] -> [B, H * W, C]
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
layer_norm.forward(&hidden_states)?
} else {
// already [B, H * W, C]
hidden_states
};
// standard self-attention
let key = self
.transpose_for_scores(self.key.forward(&hidden_states)?)?
.contiguous()?;
let value = self
.transpose_for_scores(self.value.forward(&hidden_states)?)?
.contiguous()?;
let attention_scores =
(query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let result = attention_scores.matmul(&value)?;
let result = result.permute((0, 2, 1, 3))?.contiguous()?;
result.flatten_from(D::Minus2)
}
}
#[derive(Debug, Clone)]
struct SegformerSelfOutput {
dense: Linear,
}
impl SegformerSelfOutput {
fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
Ok(Self { dense })
}
}
impl Module for SegformerSelfOutput {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.dense.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerAttention {
attention: SegformerEfficientSelfAttention,
output: SegformerSelfOutput,
}
impl SegformerAttention {
fn new(
config: &Config,
hidden_size: usize,
num_attention_heads: usize,
sequence_reduction_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
let attention = SegformerEfficientSelfAttention::new(
config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
vb.pp("self"),
)?;
let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
Ok(Self { attention, output })
}
}
impl Module for SegformerAttention {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let attention_output = self.attention.forward(x)?;
self.output.forward(&attention_output)
}
}
#[derive(Debug, Clone)]
struct SegformerDWConv {
dw_conv: Conv2d,
}
impl SegformerDWConv {
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
let dw_conv = conv2d(
dim,
dim,
3,
Conv2dConfig {
stride: 1,
padding: 1,
groups: dim,
..Default::default()
},
vb.pp("dwconv"),
)?;
Ok(Self { dw_conv })
}
}
impl Module for SegformerDWConv {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.dw_conv.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerMixFFN {
dense1: Linear,
dw_conv: SegformerDWConv,
act: Activation,
dense2: Linear,
}
impl SegformerMixFFN {
fn new(
config: &Config,
in_features: usize,
hidden_features: usize,
out_features: usize,
vb: VarBuilder,
) -> Result<Self> {
let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
let act = config.hidden_act;
let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
Ok(Self {
dense1,
dw_conv,
act,
dense2,
})
}
}
impl Module for SegformerMixFFN {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (batch, _, height, width) = x.shape().dims4()?;
let hidden_states = self
.dense1
.forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
let channels = hidden_states.dim(2)?;
let hidden_states = self.dw_conv.forward(
&hidden_states
.permute((0, 2, 1))?
.reshape((batch, channels, height, width))?,
)?;
let hidden_states = self.act.forward(&hidden_states)?;
let hidden_states = self
.dense2
.forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
let channels = hidden_states.dim(2)?;
hidden_states
.permute((0, 2, 1))?
.reshape((batch, channels, height, width))
}
}
#[derive(Debug, Clone)]
struct SegformerLayer {
layer_norm_1: candle_nn::LayerNorm,
attention: SegformerAttention,
layer_norm_2: candle_nn::LayerNorm,
mlp: SegformerMixFFN,
}
impl SegformerLayer {
fn new(
config: &Config,
hidden_size: usize,
num_attention_heads: usize,
sequence_reduction_ratio: usize,
mlp_ratio: usize,
vb: VarBuilder,
) -> Result<Self> {
let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
let attention = SegformerAttention::new(
config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
vb.pp("attention"),
)?;
let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
let mlp = SegformerMixFFN::new(
config,
hidden_size,
hidden_size * mlp_ratio,
hidden_size,
vb.pp("mlp"),
)?;
Ok(Self {
layer_norm_1,
attention,
layer_norm_2,
mlp,
})
}
}
impl Module for SegformerLayer {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let shape = x.shape().dims4()?;
// [B, C, H, W] -> [B, H * W, C]
let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
// attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
let attention_output = self.attention.forward(&layer_norm_output)?;
let hidden_states = (attention_output + hidden_states)?;
let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
let mlp_output = self
.mlp
.forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
}
}
#[derive(Debug, Clone)]
struct SegformerEncoder {
/// config file
config: Config,
/// a list of embeddings
patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
/// a list of attention blocks, each consisting of layers
blocks: Vec<Vec<SegformerLayer>>,
/// a final list of layer norms
layer_norms: Vec<candle_nn::LayerNorm>,
}
impl SegformerEncoder {
fn new(config: Config, vb: VarBuilder) -> Result<Self> {
let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
for i in 0..config.num_encoder_blocks {
let patch_size = config.patch_sizes[i];
let stride = config.strides[i];
let hidden_size = config.hidden_sizes[i];
let num_channels = if i == 0 {
config.num_channels
} else {
config.hidden_sizes[i - 1]
};
patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
&config,
patch_size,
stride,
num_channels,
hidden_size,
vb.pp(&format!("patch_embeddings.{}", i)),
)?);
let mut layers = Vec::with_capacity(config.depths[i]);
for j in 0..config.depths[i] {
let sequence_reduction_ratio = config.sr_ratios[i];
let num_attention_heads = config.num_attention_heads[i];
let mlp_ratio = config.mlp_ratios[i];
layers.push(SegformerLayer::new(
&config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
mlp_ratio,
vb.pp(&format!("block.{}.{}", i, j)),
)?);
}
blocks.push(layers);
layer_norms.push(layer_norm(
hidden_size,
config.layer_norm_eps,
vb.pp(&format!("layer_norm.{}", i)),
)?);
}
Ok(Self {
config,
patch_embeddings,
blocks,
layer_norms,
})
}
}
impl ModuleWithHiddenStates for SegformerEncoder {
fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
let mut hidden_states = x.clone();
for i in 0..self.config.num_encoder_blocks {
hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
for layer in &self.blocks[i] {
hidden_states = layer.forward(&hidden_states)?;
}
let shape = hidden_states.shape().dims4()?;
hidden_states =
self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
all_hidden_states.push(hidden_states.clone());
}
Ok(all_hidden_states)
}
}
#[derive(Debug, Clone)]
struct SegformerModel {
encoder: SegformerEncoder,
}
impl SegformerModel {
fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
Ok(Self { encoder })
}
}
impl ModuleWithHiddenStates for SegformerModel {
fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
self.encoder.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerMLP {
proj: Linear,
}
impl SegformerMLP {
fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
Ok(Self { proj })
}
}
impl Module for SegformerMLP {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.proj.forward(x)
}
}
#[derive(Debug, Clone)]
struct SegformerDecodeHead {
linear_c: Vec<SegformerMLP>,
linear_fuse: candle_nn::Conv2d,
batch_norm: candle_nn::BatchNorm,
classifier: candle_nn::Conv2d,
}
impl SegformerDecodeHead {
fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
for i in 0..config.num_encoder_blocks {
let hidden_size = config.hidden_sizes[i];
linear_c.push(SegformerMLP::new(
config,
hidden_size,
vb.pp(&format!("linear_c.{}", i)),
)?);
}
let linear_fuse = conv2d_no_bias(
config.decoder_hidden_size * config.num_encoder_blocks,
config.decoder_hidden_size,
1,
Conv2dConfig::default(),
vb.pp("linear_fuse"),
)?;
let batch_norm = candle_nn::batch_norm(
config.decoder_hidden_size,
config.layer_norm_eps,
vb.pp("batch_norm"),
)?;
let classifier = conv2d_no_bias(
config.decoder_hidden_size,
num_labels,
1,
Conv2dConfig::default(),
vb.pp("classifier"),
)?;
Ok(Self {
linear_c,
linear_fuse,
batch_norm,
classifier,
})
}
fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
if encoder_hidden_states.len() != self.linear_c.len() {
candle::bail!(
"The number of encoder hidden states {} is not equal to the number of linear layers {}",
encoder_hidden_states.len(),
self.linear_c.len()
)
}
// most fine layer
let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
let mut hidden_states = Vec::with_capacity(self.linear_c.len());
for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
let (batch, _, height, width) = hidden_state.shape().dims4()?;
let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
batch,
hidden_state.dim(2)?,
height,
width,
))?;
let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
hidden_states.push(hidden_state);
}
hidden_states.reverse();
let hidden_states = Tensor::cat(&hidden_states, 1)?;
let hidden_states = self.linear_fuse.forward(&hidden_states)?;
let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
let hidden_states = hidden_states.relu()?;
self.classifier.forward(&hidden_states)
}
}
trait ModuleWithHiddenStates {
fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
}
#[derive(Debug, Clone)]
pub struct SemanticSegmentationModel {
segformer: SegformerModel,
decode_head: SegformerDecodeHead,
}
impl SemanticSegmentationModel {
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
Ok(Self {
segformer,
decode_head,
})
}
}
impl Module for SemanticSegmentationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let hidden_states = self.segformer.forward(x)?;
self.decode_head.forward(&hidden_states)
}
}
#[derive(Debug, Clone)]
pub struct ImageClassificationModel {
segformer: SegformerModel,
classifier: Linear,
}
impl ImageClassificationModel {
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
Ok(Self {
segformer,
classifier,
})
}
}
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
let hidden_states = all_hidden_states.last().unwrap();
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_json_load() {
let raw_json = r#"{
"architectures": [
"SegformerForImageClassification"
],
"attention_probs_dropout_prob": 0.0,
"classifier_dropout_prob": 0.1,
"decoder_hidden_size": 256,
"depths": [
2,
2,
2,
2
],
"downsampling_rates": [
1,
4,
8,
16
],
"drop_path_rate": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_sizes": [
32,
64,
160,
256
],
"image_size": 224,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"mlp_ratios": [
4,
4,
4,
4
],
"model_type": "segformer",
"num_attention_heads": [
1,
2,
5,
8
],
"num_channels": 3,
"num_encoder_blocks": 4,
"patch_sizes": [
7,
3,
3,
3
],
"sr_ratios": [
8,
4,
2,
1
],
"strides": [
4,
2,
2,
2
],
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0"
}"#;
let config: Config = serde_json::from_str(raw_json).unwrap();
assert_eq!(vec![4, 2, 2, 2], config.strides);
assert_eq!(1e-6, config.layer_norm_eps);
}
}

View File

@ -4,7 +4,7 @@ use candle_nn::{Activation, LayerNorm, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm.py
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
@ -14,10 +14,10 @@ pub struct Config {
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize,
pub(crate) hidden_act: Activation,
pub(crate) partial_rotary_factor: f64,
pub(crate) rope_pct: f64,
pub(crate) rope_theta: f64,
pub(crate) max_position_embeddings: usize,
pub(crate) layer_norm_eps: f64,
pub(crate) norm_eps: f64,
pub(crate) use_cache: bool,
#[serde(default)]
pub(crate) use_qkv_bias: bool, // Used in StableLM-2
@ -35,10 +35,10 @@ impl Config {
num_attention_heads: 32,
num_key_value_heads: 32,
hidden_act: Activation::Silu,
partial_rotary_factor: 0.25,
rope_pct: 0.25,
rope_theta: 10_000.,
max_position_embeddings: 4096,
layer_norm_eps: 1e-5,
norm_eps: 1e-5,
use_qkv_bias: false,
use_cache: true,
use_flash_attn,
@ -50,7 +50,7 @@ impl Config {
}
pub fn rotary_ndims(&self) -> usize {
(self.head_dim() as f64 * self.partial_rotary_factor) as usize
(self.head_dim() as f64 * self.rope_pct) as usize
}
pub fn num_kv_groups(&self) -> usize {
@ -316,14 +316,11 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm = candle_nn::layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
vb.pp("input_layernorm"),
)?;
let input_layernorm =
candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = candle_nn::layer_norm(
cfg.hidden_size,
cfg.layer_norm_eps,
cfg.norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
@ -375,7 +372,7 @@ impl Model {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,

View File

@ -1,347 +0,0 @@
#![allow(unused)]
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
vocab_size: usize,
hidden_size: usize,
intermediate_size: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
num_key_value_heads: usize,
hidden_act: candle_nn::Activation,
max_position_embeddings: usize,
norm_epsilon: f64,
rope_theta: f64,
use_bias: bool,
sliding_window: Option<usize>,
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
c_fc: Linear,
c_proj: Linear,
act: candle_nn::Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
let c_fc = linear_b(h_size, i_size, cfg.use_bias, vb.pp("c_fc"))?;
let c_proj = linear_b(i_size, h_size, cfg.use_bias, vb.pp("c_proj"))?;
Ok(Self {
c_fc,
c_proj,
act: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.c_fc)?.apply(&self.act)?.apply(&self.c_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let b = cfg.use_bias;
let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&value_states)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb.pp("input_layernorm"))?;
let post_attention_layernorm = layer_norm(
cfg.hidden_size,
cfg.norm_epsilon,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: LayerNorm,
lm_head: Linear,
sliding_window: Option<usize>,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb_m.pp("norm"))?;
let lm_head = candle_nn::Linear::new(embed_tokens.embeddings().clone(), None);
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let sliding_window = self.sliding_window.unwrap_or(tgt_len + 42);
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}

View File

@ -0,0 +1,156 @@
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
pub struct AdaLayerNorm {
eps: f64,
dim: usize,
scale: Embedding,
shift: Embedding,
}
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = {
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
};
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)
}
impl AdaLayerNorm {
pub fn new(
num_embeddings: usize,
embedding_dim: usize,
eps: f64,
vb: VarBuilder,
) -> Result<Self> {
let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
Ok(Self {
eps,
dim: embedding_dim,
scale,
shift,
})
}
pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
let scale = self.scale.forward(cond_embedding_id)?;
let shift = self.shift.forward(cond_embedding_id)?;
let xs = layer_norm(xs, self.eps)?;
xs * scale + shift
}
}
pub struct ConvNeXtBlock {
dwconv: Conv1d,
pwconv1: Linear,
pwconv2: Linear,
gamma: Option<Tensor>,
}
impl ConvNeXtBlock {
pub fn new(
dim: usize,
intermediate_dim: usize,
layer_scale_init_value: f64,
adanorm_num_embeddings: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let dwconv = {
let cfg = Conv1dConfig {
padding: 3,
groups: dim,
..Default::default()
};
conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
};
let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
let gamma = if layer_scale_init_value > 0. {
Some(vb.get(dim, "gamma")?)
} else {
None
};
Ok(Self {
dwconv,
pwconv1,
pwconv2,
gamma,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
// TODO: norm
let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
let xs = match self.gamma.as_ref() {
Some(gamma) => (gamma * xs)?,
None => xs,
};
xs.transpose(1, 2)? + residual
}
}
struct VocosBackbone {
embed: Conv1d,
convnext: Vec<ConvNeXtBlock>,
final_layer_norm: candle_nn::LayerNorm,
}
impl VocosBackbone {
pub fn new(
input_channels: usize,
dim: usize,
intermediate_dim: usize,
num_layers: dim,
layer_scale_init_value: f64,
adanorm_num_embeddings: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let embed = {
let cfg = Conv1dConfig {
padding: 3,
..Default::default()
};
conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
};
let mut convnext = Vec::with_capacity(num_layers);
let vb_c = vb.pp("convnext");
for i in 0..num_layers {
let block = ConvNeXtBlock::new(
dim,
intermediate_dim,
layer_scale_init_value,
adanorm_num_embeddings,
vb_c.pp(i),
)?;
}
let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
Ok(Self {
embed,
convnext,
final_layer_norm,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embed)?;
// TODO: norm
let mut xs = xs.transpose(1, 2)?;
for conv_block in self.convnext.iter() {
xs = conv_block.forward(&xs)?
}
xs.apply(&self.final_layer_norm)
}
}

View File

@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
mel
}
pub fn log_mel_spectrogram_<T: Float>(
fn log_mel_spectrogram_<T: Float>(
samples: &[T],
filters: &[T],
fft_size: usize,

View File

@ -47,12 +47,6 @@ impl Linear {
}
}
pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result<Linear> {
let inner = candle_nn::linear_b(d1, d2, b, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
let inner = candle_nn::linear(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear");