Compare commits

..

1 Commits

Author SHA1 Message Date
3f3730b657 Preliminary implementation for the vocos model. 2024-02-14 22:16:09 +01:00
94 changed files with 1179 additions and 10450 deletions

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "0.4.1" version = "0.4.0"
edition = "2021" edition = "2021"
description = "Minimalist ML framework." description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle" repository = "https://github.com/huggingface/candle"
@ -31,18 +31,17 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" } accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3" byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" } candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
candle-datasets = { path = "./candle-datasets", version = "0.4.1" } candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" } candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
candle-kernels = { path = "./candle-kernels", version = "0.4.1" } candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" } candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
candle-nn = { path = "./candle-nn", version = "0.4.1" } candle-nn = { path = "./candle-nn", version = "0.4.0" }
candle-onnx = { path = "./candle-onnx", version = "0.4.1" } candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
candle-transformers = { path = "./candle-transformers", version = "0.4.1" } candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false } criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] } cudarc = { version = "0.10.0", features = ["f16"] }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0" hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }

View File

@ -63,8 +63,6 @@ We also provide a some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant. the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM. - [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports pre-trained on 1T tokens of English and code datasets. Also supports
@ -76,10 +74,9 @@ We also provide a some command line based examples using state of the art models
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of - [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference. much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and - [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
[StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs. - [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM - [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance. performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion. - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
@ -110,12 +107,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200"> <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model. - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
text-to-speech.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/), - [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings. [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@ -195,10 +187,9 @@ If you have an addition to this list, please submit a pull request.
- Language Models. - Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B. - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- Falcon. - Falcon.
- StarCoder, StarCoder2. - StarCoder.
- Phi 1, 1.5, and 2. - Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba - Mamba, Minimal Mamba
- Gemma 2b and 7b.
- Mistral 7b v0.1. - Mistral 7b v0.1.
- Mixtral 8x7b v0.1. - Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@ -206,7 +197,7 @@ If you have an addition to this list, please submit a pull request.
- Bert. - Bert.
- Yi-6B and Yi-34B. - Yi-6B and Yi-34B.
- Qwen1.5. - Qwen1.5.
- RWKV v5 and v6. - RWKV.
- Quantized LLMs. - Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants. - Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct. - Mistral 7b, and 7b instruct.
@ -216,22 +207,18 @@ If you have an addition to this list, please submit a pull request.
- Text to text. - Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction). - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation). - Marian MT (Machine Translation).
- Whisper (multi-lingual support).
- Text to image. - Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0. - Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2. - Wurstchen v2.
- Image to text. - Image to text.
- BLIP. - BLIP.
- TrOCR. - TrOCR.
- Audio.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Computer Vision Models. - Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA). ConvNeXTv2.
- yolo-v3, yolo-v8. - yolo-v3, yolo-v8.
- Segment-Anything Model (SAM). - Segment-Anything Model (SAM).
- SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files. - File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments. - Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types. - Quantization support using the llama.cpp quantized types.

View File

@ -5,32 +5,25 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::Result; use anyhow::Result;
use candle_core::{Device, Module, Tensor}; use candle_core::{Device, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
fn main() -> Result<()> { fn main() -> Result<()> {
let device = Device::new_cuda(0)?; let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?; let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?; let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?; let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let q = QMatMul::from_qtensor(q)?; println!("{out_t}");
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?; let in_t = in_t.to_device(&Device::Cpu)?;
let res_q_cuda = q.forward(&x)?; let k_t = k_t.to_device(&Device::Cpu)?;
println!("{res_q_cuda}"); let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?; .sqr()?
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?; .sum_all()?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}"); println!("{diff}");
let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
let res = t.conv2d(&w, 1, 1, 1, 1)?;
println!("{res:?}");
Ok(()) Ok(())
} }

View File

@ -113,7 +113,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Floor) | Op::Unary(_node, UnaryOp::Floor)
| Op::Unary(_node, UnaryOp::Round) => nodes, | Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node) Op::Reshape(node)
| Op::UpsampleNearest1D { arg: node, .. } | Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D { arg: node, .. } | Op::UpsampleNearest2D { arg: node, .. }
| Op::AvgPool2D { arg: node, .. } | Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. } | Op::MaxPool2D { arg: node, .. }
@ -250,7 +250,6 @@ impl Tensor {
out_padding, out_padding,
*stride, *stride,
*dilation, *dilation,
/* groups */ 1,
)?; )?;
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
@ -348,18 +347,9 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?; let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?; *sum_grad = sum_grad.add(&grad_arg)?;
} }
Op::UpsampleNearest1D { arg, target_size } => { Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
let (_n, c, size) = arg.dims3()?; op: "upsample-nearest1d",
if target_size % size != 0 { })?,
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale = target_size / size;
let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest2D { Op::UpsampleNearest2D {
arg, arg,
target_h, target_h,

View File

@ -187,16 +187,36 @@ impl Tensor {
} }
} }
fn conv_transpose1d_single_group( /// Applies a 1D transposed convolution over the input tensor.
pub fn conv_transpose1d(
&self, &self,
kernel: &Self, kernel: &Self,
params: &ParamsConvTranspose1D, padding: usize,
output_padding: usize,
stride: usize,
dilation: usize,
) -> Result<Self> { ) -> Result<Self> {
let (b_size, c_in, l_in) = self.dims3()?;
let (c_in_k, c_out, k_size) = kernel.dims3()?;
if c_in != c_in_k {
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
}
let params = ParamsConvTranspose1D {
b_size,
l_in,
k_size,
c_out,
c_in,
padding,
output_padding,
stride,
dilation,
};
let storage = self.storage().conv_transpose1d( let storage = self.storage().conv_transpose1d(
self.layout(), self.layout(),
&kernel.storage(), &kernel.storage(),
kernel.layout(), kernel.layout(),
params, &params,
)?; )?;
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D { let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
arg, arg,
@ -210,49 +230,6 @@ impl Tensor {
Ok(crate::tensor::from_storage(storage, out_dims, op, false)) Ok(crate::tensor::from_storage(storage, out_dims, op, false))
} }
/// Applies a 1D transposed convolution over the input tensor.
///
/// `self` is read as `(b_size, c_in, l_in)` and `kernel` as `(c_in, c_out, k_size)`.
/// With `groups == 1` a single transposed convolution is run. Otherwise the input is
/// chunked along dim 1 and the kernel along dim 0 into `groups` pieces, each pair is
/// convolved on its own, and the results are concatenated along dim 1.
///
/// # Errors
///
/// Fails when either tensor is not rank 3, when the channel counts of the input and
/// the kernel differ, when `groups` is zero, or when `c_in` is not divisible by
/// `groups`.
pub fn conv_transpose1d(
    &self,
    kernel: &Self,
    padding: usize,
    output_padding: usize,
    stride: usize,
    dilation: usize,
    groups: usize,
) -> Result<Self> {
    let (c_in_k, c_out, k_size) = kernel.dims3()?;
    let (b_size, c_in, l_in) = self.dims3()?;
    if c_in != c_in_k {
        crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
    }
    // Reject zero explicitly: the modulo and division below would otherwise panic.
    if groups == 0 {
        crate::bail!("the number of groups must be strictly positive")
    }
    if c_in % groups != 0 {
        crate::bail!("in_channel {c_in} is not divisible by the number of groups")
    }
    // The parameters describe a single group, hence the per-group channel count.
    let params = ParamsConvTranspose1D {
        b_size,
        l_in,
        k_size,
        c_out,
        c_in: c_in / groups,
        padding,
        output_padding,
        stride,
        dilation,
    };
    if groups == 1 {
        self.conv_transpose1d_single_group(kernel, &params)
    } else {
        let blocks = self.chunk(groups, 1)?;
        let kernel = kernel.chunk(groups, 0)?;
        let blocks = blocks
            .iter()
            .zip(&kernel)
            .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
            .collect::<Result<Vec<_>>>()?;
        Tensor::cat(&blocks, 1)
    }
}
fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> { fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
let storage = let storage =
self.storage() self.storage()

View File

@ -1263,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> { fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
let p = self.0; let p = self.0;
let inp = &inp[inp_l.start_offset()..]; let inp = &inp[inp_l.start_offset()..];
let k = &k[k_l.start_offset()..];
let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?; let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?; let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
let l_out = p.l_out(); let l_out = p.l_out();
@ -2575,7 +2574,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
} }
} }
@ -2584,7 +2583,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
} }
} }
@ -2601,7 +2600,7 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
} }
} }

View File

@ -129,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
} }
} }
/// An optional module: `None` acts as the identity, `Some(m)` delegates to `m`.
impl<M: Module> Module for Option<&M> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        if let Some(inner) = self {
            inner.forward(xs)
        } else {
            // No module configured: hand back the input unchanged.
            Ok(xs.clone())
        }
    }
}
// A trait defining a module with forward method using a single tensor argument and a flag to // A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors. // separate the training and evaluation behaviors.
pub trait ModuleT { pub trait ModuleT {

View File

@ -738,7 +738,6 @@ impl BackendStorage for MetalStorage {
("ufloor", DType::F32) => strided::floor::FLOAT, ("ufloor", DType::F32) => strided::floor::FLOAT,
("urelu", DType::F32) => strided::relu::FLOAT, ("urelu", DType::F32) => strided::relu::FLOAT,
("uround", DType::F32) => strided::round::FLOAT, ("uround", DType::F32) => strided::round::FLOAT,
("utanh", DType::F32) => strided::tanh::FLOAT,
("ucos", DType::F16) => strided::cos::HALF, ("ucos", DType::F16) => strided::cos::HALF,
("usin", DType::F16) => strided::sin::HALF, ("usin", DType::F16) => strided::sin::HALF,
("usqr", DType::F16) => strided::sqr::HALF, ("usqr", DType::F16) => strided::sqr::HALF,
@ -755,7 +754,6 @@ impl BackendStorage for MetalStorage {
("ufloor", DType::F16) => strided::floor::HALF, ("ufloor", DType::F16) => strided::floor::HALF,
("urelu", DType::F16) => strided::relu::HALF, ("urelu", DType::F16) => strided::relu::HALF,
("uround", DType::F16) => strided::round::HALF, ("uround", DType::F16) => strided::round::HALF,
("utanh", DType::F16) => strided::tanh::HALF,
(name, dtype) => { (name, dtype) => {
crate::bail!("Metal strided unary {name} {dtype:?} not implemented") crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
} }
@ -829,9 +827,9 @@ impl BackendStorage for MetalStorage {
layout.start_offset() * self.dtype.size_in_bytes(), layout.start_offset() * self.dtype.size_in_bytes(),
), ),
&t.buffer, &t.buffer,
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()), (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer, &f.buffer,
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()), (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1266,7 +1264,7 @@ impl BackendStorage for MetalStorage {
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger; let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger; let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger; let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length); blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
blit.end_encoding(); blit.end_encoding();
} else { } else {
let src_shape = src_l.shape(); let src_shape = src_l.shape();
@ -1638,7 +1636,7 @@ impl BackendDevice for MetalDevice {
min as f32, min as f32,
max as f32, max as f32,
shape.elem_count(), shape.elem_count(),
&self.seed.lock().unwrap(), &*self.seed.lock().unwrap(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
@ -1669,7 +1667,7 @@ impl BackendDevice for MetalDevice {
mean as f32, mean as f32,
stddev as f32, stddev as f32,
shape.elem_count(), shape.elem_count(),
&self.seed.lock().unwrap(), &*self.seed.lock().unwrap(),
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;

View File

@ -132,10 +132,7 @@ pub enum Op {
stride: (usize, usize), stride: (usize, usize),
}, },
UpsampleNearest1D { UpsampleNearest1D(Tensor),
arg: Tensor,
target_size: usize,
},
UpsampleNearest2D { UpsampleNearest2D {
arg: Tensor, arg: Tensor,
target_h: usize, target_h: usize,

View File

@ -42,7 +42,7 @@ pub enum OpCode {
Stop = b'.', Stop = b'.',
NewObj = 0x81, NewObj = 0x81,
EmptyList = b']', EmptyList = b']',
BinFloat = b'G', BinFloat = b'g',
Append = b'a', Append = b'a',
Appends = b'e', Appends = b'e',
} }
@ -462,10 +462,7 @@ impl Stack {
self.push(Object::Int(arg)) self.push(Object::Int(arg))
} }
OpCode::BinFloat => { OpCode::BinFloat => {
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian. let arg = r.read_f64::<LittleEndian>()?;
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
let arg = r.read_f64::<byteorder::BigEndian>()?;
self.push(Object::Float(arg)) self.push(Object::Float(arg))
} }
OpCode::BinUnicode => { OpCode::BinUnicode => {

View File

@ -1,343 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use cudarc::driver::{CudaSlice, DeviceSlice};
/// Quantized tensor data held in CUDA device memory.
pub struct QCudaStorage {
    // Raw bytes of the quantized blocks, stored on the device.
    data: CudaSlice<u8>,
    // GGML quantization format describing how `data` is encoded.
    dtype: GgmlDType,
    // Device handle used for allocations, host/device copies and kernel launches.
    device: CudaDevice,
}
// Launch-geometry constants for the quantized CUDA kernels.
// NOTE(review): names and values look like they mirror llama.cpp's ggml-cuda — confirm.
// Threads along x when launching the `dequantize_mul_mat_vec_*` kernels.
pub const WARP_SIZE: usize = 32;
// The three Q4_0 tiling constants below are not referenced in the visible code.
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
// Not referenced in the visible code.
pub const GGML_CUDA_MMV_X: usize = 32;
// Rows handled per block (and threads along y) for the mat-vec kernels.
pub const GGML_CUDA_MMV_Y: usize = 1;
// Threads per block for the Q5_0 / Q5_1 dequantize kernels; also sizes their grid.
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
/// Dequantizes `elem_count` values stored in `data` (encoded as `dtype`) into a newly
/// allocated f32 buffer on `dev`, using the matching `dequantize_block_*` kernel.
///
/// Returns an error for dtypes that have no kernel listed below, or when loading the
/// kernel, allocating the output, or launching fails.
fn dequantize(
    data: &CudaSlice<u8>,
    dtype: GgmlDType,
    elem_count: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    // Grid size used by every format except Q5_0 / Q5_1.
    let default_grid = (elem_count + 255) / 256;
    // Q5_0 / Q5_1 derive their grid from CUDA_DEQUANTIZE_BLOCK_SIZE instead.
    let q5_grid = || {
        (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE)
    };
    // (kernel name, is a k-quant kernel, threads per block, number of blocks)
    let (kernel_name, is_k, threads, grid) = match dtype {
        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, default_grid),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, default_grid),
        GgmlDType::Q5_0 => (
            "dequantize_block_q5_0",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            q5_grid(),
        ),
        GgmlDType::Q5_1 => (
            "dequantize_block_q5_1",
            false,
            CUDA_DEQUANTIZE_BLOCK_SIZE,
            q5_grid(),
        ),
        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, default_grid),
        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, default_grid),
        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, default_grid),
        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, default_grid),
        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, default_grid),
        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, default_grid),
        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, default_grid),
        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
    // See e.g.
    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (grid as u32, 1, 1),
        block_dim: (threads as u32, 1, 1),
        shared_mem_bytes: 0,
    };
    let launched = if is_k {
        // The k-quant kernels only take the source and destination buffers.
        unsafe { func.launch(cfg, (data, &dst)) }
    } else {
        // The other kernels take one extra integer argument.
        let nb32 = if matches!(dtype, GgmlDType::Q5_0 | GgmlDType::Q5_1) {
            elem_count
        } else {
            elem_count / 32
        };
        unsafe { func.launch(cfg, (data, &dst, nb32 as i32)) }
    };
    launched.w()?;
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
/// Launches the fused dequantize + matrix-vector kernel for `dtype`.
///
/// `data` holds the quantized matrix, described by `nrows` and `ncols`, and `y` is the
/// f32 vector it is multiplied with. The result is a newly allocated f32 buffer of
/// `nrows` elements on `dev`. Dtypes without a kernel listed below produce an error.
fn dequantize_mut_mal_vec(
    data: &CudaSlice<u8>,
    y: &cudarc::driver::CudaView<f32>,
    dtype: GgmlDType,
    ncols: usize,
    nrows: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::{LaunchAsync, LaunchConfig};

    let kernel_name = match dtype {
        GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
        GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
        GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
        GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
        GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
        GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
        GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
        GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
        GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
        GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(nrows).w()?;

    // One block per GGML_CUDA_MMV_Y rows, rounded up.
    let grid_x = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    let cfg = LaunchConfig {
        grid_dim: (grid_x as u32, 1, 1),
        block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
        shared_mem_bytes: 0,
    };
    unsafe { func.launch(cfg, (data, y, &dst, ncols as i32, nrows as i32)) }.w()?;
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
impl QCudaStorage {
    /// Allocates zero-initialised device storage sized for `el_count` elements of `dtype`.
    pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
        let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
        let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
        Ok(QCudaStorage {
            data,
            device: device.clone(),
            dtype,
        })
    }

    /// Quantization format of the stored data.
    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    /// Device on which the data lives.
    pub fn device(&self) -> &CudaDevice {
        &self.device
    }

    /// Dequantizes the first `elem_count` elements into f32 storage on the same device.
    ///
    /// Formats with a dedicated CUDA kernel are handled on the device. The remaining
    /// ones are copied to the host, decoded there, and uploaded again.
    pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
        let fast_kernel = matches!(
            self.dtype,
            GgmlDType::Q4_0
                | GgmlDType::Q4_1
                | GgmlDType::Q5_0
                | GgmlDType::Q5_1
                | GgmlDType::Q8_0
                | GgmlDType::Q2K
                | GgmlDType::Q3K
                | GgmlDType::Q4K
                | GgmlDType::Q5K
                | GgmlDType::Q6K
                | GgmlDType::Q8K
        );
        if fast_kernel {
            return dequantize(&self.data, self.dtype, elem_count, self.device());
        }
        // Run the dequantization on cpu.
        use crate::quantized::k_quants::GgmlType;

        let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
        let mut out = vec![0.0; elem_count];
        let block_len = elem_count / self.dtype.block_size();
        match self.dtype {
            GgmlDType::F32 => {
                // Decode the bytes explicitly rather than reinterpreting the `u8`
                // buffer as `&[f32]`: the byte buffer carries no alignment guarantee
                // for f32 and its length has to be validated first.
                const F32_SIZE: usize = std::mem::size_of::<f32>();
                if buffer.len() / F32_SIZE < out.len() {
                    crate::bail!(
                        "buffer of {} bytes is too small to hold {} f32 values",
                        buffer.len(),
                        elem_count
                    )
                }
                for (dst, src) in out.iter_mut().zip(buffer.chunks_exact(F32_SIZE)) {
                    *dst = f32::from_ne_bytes([src[0], src[1], src[2], src[3]]);
                }
            }
            GgmlDType::F16 => {
                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
                half::f16::to_float(&vec, &mut out)?;
            }
            // The arms below other than Q8_1 are shadowed by the fast-kernel early
            // return above; they are kept so the match stays exhaustive.
            GgmlDType::Q4_0 => {
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_1 => {
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_0 => {
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_1 => {
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_0 => {
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_1 => {
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q2K => {
                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q3K => {
                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4K => {
                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5K => {
                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q6K => {
                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8K => {
                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
            }
        }

        self.device
            .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
    }

    /// Quantizes the f32 data in `src` into this storage, replacing its contents.
    ///
    /// The work is done on the host: `src` is copied down, quantized with the CPU
    /// implementation, and the resulting bytes are uploaded. Only f32 sources are
    /// accepted.
    pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
        // Run the quantization on cpu.
        let src = match &src.slice {
            crate::cuda_backend::CudaStorageSlice::F32(data) => {
                self.device.dtoh_sync_copy(data).w()?
            }
            _ => crate::bail!("only f32 can be quantized"),
        };
        let src_len = src.len();
        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
        qcpu_storage.quantize(&src)?;
        let data = qcpu_storage.data()?;
        let data = self.device.htod_sync_copy(data.as_ref()).w()?;
        self.data = data;
        Ok(())
    }

    /// Number of bytes of quantized data held on the device.
    pub fn storage_size_in_bytes(&self) -> usize {
        self.data.len()
    }

    /// Multiplies the input described by `storage` and `layout` with this quantized
    /// matrix, whose logical shape is `self_shape`.
    ///
    /// Inputs shaped `[1, k]` or `[1, 1, k]` go through the fused mat-vec kernel. All
    /// other shapes dequantize the matrix first and run a regular matmul.
    pub fn fwd(
        &self,
        self_shape: &crate::Shape,
        storage: &CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
            self.dequantize_matmul_vec(self_shape, storage, layout)
        } else {
            self.dequantize_matmul(self_shape, storage, layout)
        }
    }
}
impl QCudaStorage {
fn dequantize_matmul_vec(
&self,
self_shape: &crate::Shape,
rhs: &CudaStorage,
rhs_l: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
let (nrows, ncols) = self_shape.dims2()?;
let rhs = rhs.as_cuda_slice::<f32>()?;
let rhs = match rhs_l.contiguous_offsets() {
Some((o1, o2)) => rhs.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
};
let (with_batch, k) = match rhs_l.shape().dims() {
[1, 1, k] => (true, k),
[1, k] => (false, k),
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
};
if ncols != *k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
}
let out =
dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
let out_shape = if with_batch {
vec![1, 1, nrows]
} else {
vec![1, nrows]
};
Ok((out, out_shape.into()))
}
fn dequantize_matmul(
&self,
self_shape: &crate::Shape,
storage: &CudaStorage,
layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
use crate::backend::BackendStorage;
let (n, k) = self_shape.dims2()?;
let (b, m, k2) = match layout.shape().dims() {
&[b, m, k2] => (b, m, k2),
&[m, k2] => (1, m, k2),
s => crate::bail!("unexpected shape for input {s:?}"),
};
if k2 != k {
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
}
let data_f32 = self.dequantize(n * k)?;
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
let mut out_shape = layout.shape().dims().to_vec();
out_shape.pop();
out_shape.push(n);
Ok((out, out_shape.into()))
}
}
/// Copies the first `n` values of type `T` out of a raw byte buffer.
///
/// The bytes are reinterpreted bit-for-bit, so `T` is expected to be a plain-old-data
/// type for which every bit pattern is a valid value. The buffer does not need to be
/// aligned for `T`: the data is copied byte-wise into a properly aligned `Vec<T>`
/// instead of being viewed in place.
///
/// # Panics
///
/// Panics if `buffer` holds fewer than `n * size_of::<T>()` bytes.
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
    let byte_len = n
        .checked_mul(std::mem::size_of::<T>())
        .expect("read_to_vec: requested byte length overflows usize");
    assert!(
        byte_len <= buffer.len(),
        "read_to_vec: need {} bytes for {} elements but the buffer holds {}",
        byte_len,
        n,
        buffer.len()
    );
    let mut out = Vec::<T>::with_capacity(n);
    // SAFETY: `out` has capacity for `n` values of `T`, i.e. at least `byte_len` bytes,
    // and `buffer` was just checked to hold at least `byte_len` bytes. The regions
    // cannot overlap because `out` is a fresh allocation. After the copy the first `n`
    // slots are initialised, so setting the length to `n` is valid.
    unsafe {
        std::ptr::copy_nonoverlapping(buffer.as_ptr(), out.as_mut_ptr().cast::<u8>(), byte_len);
        out.set_len(n);
    }
    out
}
/// Uploads already-quantized blocks to `device` and wraps them as CUDA quantized
/// storage tagged with `T::DTYPE`.
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    device: &CudaDevice,
    data: &[T],
) -> Result<super::QStorage> {
    let byte_len = core::mem::size_of_val(data);
    // View the block slice as raw bytes for the host-to-device copy.
    let bytes = unsafe { std::slice::from_raw_parts(data.as_ptr() as *const u8, byte_len) };
    let device_bytes = device.htod_sync_copy(bytes).w()?;
    let storage = QCudaStorage {
        data: device_bytes,
        device: device.clone(),
        dtype: T::DTYPE,
    };
    Ok(QStorage::Cuda(storage))
}

View File

@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};
/// Stand-in for the CUDA quantized storage: every fallible operation on it
/// returns `Error::NotCompiledWithCudaSupport`.
pub struct QCudaStorage {
// Quantization format reported by `dtype()`.
dtype: GgmlDType,
// Device handle reported by `device()`.
device: CudaDevice,
}
impl QCudaStorage {
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &CudaDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
Err(Error::NotCompiledWithCudaSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &CudaStorage,
_layout: &crate::Layout,
) -> Result<(CudaStorage, crate::Shape)> {
Err(Error::NotCompiledWithCudaSupport)
}
}
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &CudaDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -41,10 +41,3 @@ impl QMetalStorage {
Err(Error::NotCompiledWithMetalSupport) Err(Error::NotCompiledWithMetalSupport)
} }
} }
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
_device: &MetalDevice,
_data: &[T],
) -> Result<super::QStorage> {
Err(Error::NotCompiledWithMetalSupport)
}

View File

@ -1,5 +1,7 @@
//! Support for the GGML file format. //! Support for the GGML file format.
#[cfg(feature = "metal")]
use super::metal::load_quantized_metal;
use super::{k_quants, GgmlDType, QStorage}; use super::{k_quants, GgmlDType, QStorage};
use crate::{Device, Result}; use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) }; let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
let data: QStorage = match device { let data: QStorage = match device {
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())), Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
Device::Metal(metal) => super::metal::load_quantized(metal, data)?, #[cfg(feature = "metal")]
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?, Device::Metal(metal) => load_quantized_metal(metal, data)?,
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal backend requires `metal` feature")
}
device => unimplemented!("Implement quantized tensor for device {device:?}"),
}; };
super::QTensor::new(data, dims) super::QTensor::new(data, dims)
} }

View File

@ -34,8 +34,6 @@ impl QMetalStorage {
} }
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> { pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
use crate::quantized::k_quants::GgmlType;
let buffer = self.device.new_buffer_managed(self.buffer.length())?; let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu"); command_buffer.set_label("to_cpu");
@ -45,62 +43,81 @@ impl QMetalStorage {
blit.end_encoding(); blit.end_encoding();
self.device.wait_until_completed()?; self.device.wait_until_completed()?;
let mut out = vec![0.0; elem_count]; let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype { match self.dtype {
GgmlDType::F32 => { GgmlDType::F32 => {
let vec: Vec<f32> = read_to_vec(&buffer, block_len); let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
f32::to_float(&vec, &mut out)?; f32::to_float(&vec, &mut out)?;
} }
GgmlDType::F16 => { GgmlDType::F16 => {
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len); let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
half::f16::to_float(&vec, &mut out)?; half::f16::to_float(&vec, &mut out)?;
} }
GgmlDType::Q4_0 => { GgmlDType::Q4_0 => {
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?; crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
} }
GgmlDType::Q4_1 => { GgmlDType::Q4_1 => {
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?; crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
} }
GgmlDType::Q5_0 => { GgmlDType::Q5_0 => {
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?; crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
} }
GgmlDType::Q5_1 => { GgmlDType::Q5_1 => {
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?; crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
} }
GgmlDType::Q8_0 => { GgmlDType::Q8_0 => {
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?; crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
} }
GgmlDType::Q8_1 => { GgmlDType::Q8_1 => {
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?; crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
} }
GgmlDType::Q2K => { GgmlDType::Q2K => {
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ2K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?; crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q3K => { GgmlDType::Q3K => {
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ3K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?; crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q4K => { GgmlDType::Q4K => {
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ4K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?; crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q5K => { GgmlDType::Q5K => {
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ5K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?; crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q6K => { GgmlDType::Q6K => {
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ6K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?; crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
} }
GgmlDType::Q8K => { GgmlDType::Q8K => {
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len); let vec: Vec<crate::quantized::BlockQ8K> =
read_to_vec(&buffer, elem_count / self.dtype.block_size());
use crate::quantized::k_quants::GgmlType;
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
} }
} }
@ -175,7 +192,7 @@ impl QMetalStorage {
} }
} }
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>( pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
device: &MetalDevice, device: &MetalDevice,
data: &[T], data: &[T],
) -> Result<QStorage> { ) -> Result<QStorage> {

View File

@ -4,7 +4,6 @@ use std::borrow::Cow;
#[cfg(target_feature = "avx")] #[cfg(target_feature = "avx")]
pub mod avx; pub mod avx;
mod dummy_cuda;
mod dummy_metal; mod dummy_metal;
pub mod ggml_file; pub mod ggml_file;
pub mod gguf_file; pub mod gguf_file;
@ -15,13 +14,6 @@ pub mod metal;
mod metal { mod metal {
pub use super::dummy_metal::*; pub use super::dummy_metal::*;
} }
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
pub use super::dummy_cuda::*;
}
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")] #[cfg(target_feature = "simd128")]
@ -47,9 +39,8 @@ impl Device {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?; let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
Ok(QStorage::Metal(storage)) Ok(QStorage::Metal(storage))
} }
Device::Cuda(cuda) => { Device::Cuda(_cuda) => {
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?; crate::bail!("Cuda ggml quantization not supported");
Ok(QStorage::Cuda(storage))
} }
} }
} }
@ -58,7 +49,6 @@ impl Device {
pub enum QStorage { pub enum QStorage {
Cpu(Box<dyn QuantizedType>), Cpu(Box<dyn QuantizedType>),
Metal(metal::QMetalStorage), Metal(metal::QMetalStorage),
Cuda(cuda::QCudaStorage),
} }
impl QStorage { impl QStorage {
@ -66,7 +56,6 @@ impl QStorage {
match self { match self {
QStorage::Cpu(storage) => storage.block_size(), QStorage::Cpu(storage) => storage.block_size(),
QStorage::Metal(storage) => storage.dtype().block_size(), QStorage::Metal(storage) => storage.dtype().block_size(),
QStorage::Cuda(storage) => storage.dtype().block_size(),
} }
} }
@ -74,7 +63,6 @@ impl QStorage {
match self { match self {
QStorage::Cpu(storage) => storage.dtype(), QStorage::Cpu(storage) => storage.dtype(),
QStorage::Metal(storage) => storage.dtype(), QStorage::Metal(storage) => storage.dtype(),
QStorage::Cuda(storage) => storage.dtype(),
} }
} }
@ -82,7 +70,6 @@ impl QStorage {
match self { match self {
QStorage::Cpu(_storage) => Device::Cpu, QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()), QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
} }
} }
@ -90,7 +77,6 @@ impl QStorage {
match self { match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(), QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(), QStorage::Metal(storage) => storage.storage_size_in_bytes(),
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
} }
} }
@ -100,7 +86,6 @@ impl QStorage {
storage.from_float(src.as_slice::<f32>()?)?; storage.from_float(src.as_slice::<f32>()?)?;
} }
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"), _ => crate::bail!("Invalid dequantize storage locations do not match"),
} }
Ok(()) Ok(())
@ -110,7 +95,6 @@ impl QStorage {
match self { match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
} }
} }
@ -122,7 +106,7 @@ impl QStorage {
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data)) Ok(Cow::from(data))
} }
QStorage::Metal(_) | QStorage::Cuda(_) => { QStorage::Metal(_storage) => {
crate::bail!("not implemented"); crate::bail!("not implemented");
} }
} }
@ -440,7 +424,7 @@ impl crate::CustomOp1 for QTensor {
#[allow(clippy::infallible_destructuring_match)] #[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage { let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage, QStorage::Cpu(storage) => storage,
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"), QStorage::Metal(_) => crate::bail!("Invalid storage"),
}; };
let slice = storage.as_slice::<f32>()?; let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@ -460,18 +444,6 @@ impl crate::CustomOp1 for QTensor {
}; };
self_storage.fwd(&self.shape, storage, layout) self_storage.fwd(&self.shape, storage, layout)
} }
fn cuda_fwd(
&self,
storage: &crate::CudaStorage,
layout: &crate::Layout,
) -> Result<(crate::CudaStorage, Shape)> {
let self_storage = match &self.storage {
QStorage::Cuda(cuda) => cuda,
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
};
self_storage.fwd(&self.shape, storage, layout)
}
} }
impl crate::Module for QMatMul { impl crate::Module for QMatMul {

View File

@ -352,10 +352,6 @@ impl Storage {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?; let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s)) Ok(Self::Cuda(s))
} }
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(), lhs: lhs.device().location(),
rhs: rhs.device().location(), rhs: rhs.device().location(),

View File

@ -1015,7 +1015,7 @@ impl Tensor {
/// tensor also has three dimensions, `(batch, channels, target_size)`. /// tensor also has three dimensions, `(batch, channels, target_size)`.
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> { pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
let (n, c, _l) = self.dims3()?; let (n, c, _l) = self.dims3()?;
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size }); let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
let storage = self let storage = self
.storage() .storage()
.upsample_nearest1d(self.layout(), target_size)?; .upsample_nearest1d(self.layout(), target_size)?;

View File

@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
res = torch.nn.functional.conv_transpose1d(t, w_t) res = torch.nn.functional.conv_transpose1d(t, w_t)
print(res.shape) print(res.shape)
print(res) print(res)
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
print(res.shape)
print(res)
*/ */
fn conv1d(dev: &Device) -> Result<()> { fn conv1d(dev: &Device) -> Result<()> {
let t = Tensor::new( let t = Tensor::new(
@ -53,7 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
); );
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?; let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]); assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!( assert_eq!(
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@ -62,17 +59,6 @@ fn conv1d(dev: &Device) -> Result<()> {
4.7076, -5.9745, -0.8276, 1.621 4.7076, -5.9745, -0.8276, 1.621
], ],
); );
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
assert_eq!(res.dims(), [1, 4, 7]);
assert_eq!(
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
[
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
]
);
Ok(()) Ok(())
} }

View File

@ -283,38 +283,19 @@ fn unary_grad(device: &Device) -> Result<()> {
[1.0881, 0.9277, 1.0527, 0.5747], [1.0881, 0.9277, 1.0527, 0.5747],
); );
if device.is_cpu() {
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
let y = x.interpolate1d(12)?.reshape(36)?;
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
33., 34., 35., 36.,
],
device,
)?;
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(grad_x, 4)?,
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
);
}
// manually checked: see comments // manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?; let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?; let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new( let z = Tensor::new(
&[ &[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., 1_f32, 02., 03., 04., 05., 06.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 07., 08., 09., 10., 11., 12.,
35., 36., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
], ],
device, device,
)?; )?;
@ -345,11 +326,15 @@ fn unary_grad(device: &Device) -> Result<()> {
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?; let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?; let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new( let z = Tensor::new(
&[ &[
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17., 1_f32, 02., 03., 04., 05., 06.,
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 07., 08., 09., 10., 11., 12.,
35., 36., 13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
], ],
device, device,
)?; )?;

View File

@ -178,6 +178,10 @@ test_device!(
); );
fn quantize_q4_0(device: &Device) -> Result<()> { fn quantize_q4_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
@ -205,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
} }
fn quantize_q4_1(device: &Device) -> Result<()> { fn quantize_q4_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
@ -231,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
} }
fn quantize_q5_0(device: &Device) -> Result<()> { fn quantize_q5_0(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@ -257,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
} }
fn quantize_q5_1(device: &Device) -> Result<()> { fn quantize_q5_1(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>(); let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
let src = Tensor::from_slice(&src, (32 * 4,), device)?; let src = Tensor::from_slice(&src, (32 * 4,), device)?;
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?; let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
@ -357,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
} }
fn quantize_q2k(device: &Device) -> Result<()> { fn quantize_q2k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q2K; let dtype = GgmlDType::Q2K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
@ -391,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
} }
fn quantize_q3k(device: &Device) -> Result<()> { fn quantize_q3k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q3K; let dtype = GgmlDType::Q3K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -424,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
} }
fn quantize_q4k(device: &Device) -> Result<()> { fn quantize_q4k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q4K; let dtype = GgmlDType::Q4K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -457,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
} }
fn quantize_q5k(device: &Device) -> Result<()> { fn quantize_q5k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q5K; let dtype = GgmlDType::Q5K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -490,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
} }
fn quantize_q6k(device: &Device) -> Result<()> { fn quantize_q6k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q6K; let dtype = GgmlDType::Q6K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -523,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
} }
fn quantize_q8k(device: &Device) -> Result<()> { fn quantize_q8k(device: &Device) -> Result<()> {
// TODO Enable this later when we enable cuda.
if device.is_cuda() {
return Ok(());
}
let dtype = GgmlDType::Q8K; let dtype = GgmlDType::Q8K;
let src = get_test_vector2(0.5, 1024, device)?; let src = get_test_vector2(0.5, 1024, device)?;
let quant = quantized::QTensor::quantize(&src, dtype)?; let quant = quantized::QTensor::quantize(&src, dtype)?;
@ -738,6 +778,10 @@ macro_rules! quantized_matmul {
// stable. https://github.com/rust-lang/rust/issues/29599 // stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => { ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
fn $fn_name(device: &Device) -> Result<()> { fn $fn_name(device: &Device) -> Result<()> {
if device.is_cuda() {
// TODO Enable Cuda GGML sometime maybe.
return Ok(());
}
test_matmul(device, (1, 3, 4, 256), $dtype)?; test_matmul(device, (1, 3, 4, 256), $dtype)?;
Ok(()) Ok(())
} }

View File

@ -12,7 +12,7 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { workspace = true } candle = { workspace = true }
candle-datasets = { workspace = true, optional = true } candle-datasets = { workspace = true }
candle-nn = { workspace = true } candle-nn = { workspace = true }
candle-transformers = { workspace = true } candle-transformers = { workspace = true }
candle-flash-attn = { workspace = true, optional = true } candle-flash-attn = { workspace = true, optional = true }
@ -30,7 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true } symphonia = { version = "0.5.3", features = ["all"] }
tokenizers = { workspace = true, features = ["onig"] } tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true } cpal= { version = "0.15.2", optional = true }
@ -80,26 +80,6 @@ required-features = ["onnx"]
name = "onnx_basics" name = "onnx_basics"
required-features = ["onnx"] required-features = ["onnx"]
[[example]]
name = "whisper"
required-features = ["symphonia"]
[[example]] [[example]]
name = "whisper-microphone" name = "whisper-microphone"
required-features = ["microphone"] required-features = ["microphone"]
[[example]]
name = "mnist-training"
required-features = ["candle-datasets"]
[[example]]
name = "llama2-c"
required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["symphonia"]
[[example]]
name = "metavoice"
required-features = ["symphonia"]

View File

@ -1 +0,0 @@
/// Contents of `layernorm_kernels.ptx`, read from the build's `OUT_DIR` and
/// embedded into the binary as a string at compile time.
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -1,20 +0,0 @@
# candle-efficientvit
[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).
This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet : 2.25%
alp : 0.46%
```

View File

@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientvit;
// Model variant selector, exposed on the command line through the `--which`
// option of `Args`. Each variant maps to a hub repository name and a model
// configuration in the `impl Which` block.
// Plain `//` comments are used on purpose: `///` doc comments on clap-derived
// items are picked up as help text and would change the program's output.
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
M0,
M1,
M2,
M3,
M4,
M5,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::M0 => "m0",
Self::M1 => "m1",
Self::M2 => "m2",
Self::M3 => "m3",
Self::M4 => "m4",
Self::M5 => "m5",
};
format!("timm/efficientvit_{}.r224_in1k", name)
}
fn config(&self) -> efficientvit::Config {
match self {
Self::M0 => efficientvit::Config::m0(),
Self::M1 => efficientvit::Config::m1(),
Self::M2 => efficientvit::Config::m2(),
Self::M3 => efficientvit::Config::m3(),
Self::M4 => efficientvit::Config::m4(),
Self::M5 => efficientvit::Config::m5(),
}
}
}
// Command-line options for the example, parsed in `main` with `Args::parse()`.
// New notes below use plain `//` comments: `///` doc comments on clap-derived
// items are picked up as help text and would change the program's output.
#[derive(Parser)]
struct Args {
// Optional weights file. When absent, `main` fetches `model.safetensors` from
// the hub repository named by `which`.
#[arg(long)]
model: Option<String>,
// Path of the image to classify (required).
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
// Model variant to run; defaults to `M0`.
#[arg(value_enum, long, default_value_t=Which::M0)]
which: Which,
}
/// Classifies one image with an EfficientViT model and prints the five most
/// probable classes with their probabilities.
///
/// Weights come from `--model` when given; otherwise `model.safetensors` is
/// fetched from the hub repository of the selected variant.
pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => hf_hub::api::sync::Api::new()?
            .model(args.which.model_filename())
            .get("model.safetensors")?,
        Some(path) => path.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let config = args.which.config();
    let model = efficientvit::efficientvit(&config, 1000, vb)?;
    println!("model built");

    // Add a leading batch dimension, run the network, and turn the logits of
    // the single batch entry into probabilities.
    let batch = image.unsqueeze(0)?;
    let logits = model.forward(&batch)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;

    // Pair every probability with its class index, then order by decreasing
    // probability.
    let mut ranked: Vec<_> = probs.iter().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.total_cmp(lhs.1));

    for (category_idx, pr) in ranked.into_iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-encodec
[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.
## Running one example
```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
candle-examples/examples/encodec/jfk-codes.safetensors \
jfk.wav
```
This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
containing EnCodec tokens for the input audio file.

View File

@ -1,143 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
/// Appends the first channel of `data` to `samples`, converting every sample
/// to `f32` through symphonia's `FromSample` conversion.
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;

    // Only channel 0 is read; any other channel in the buffer is ignored.
    let first_channel = data.chan(0);
    samples.extend(first_channel.iter().map(|&sample| f32::from_sample(sample)));
}
/// Decodes the audio file at `path` into f32 PCM samples.
///
/// Returns the samples of the first channel of the first track whose codec is
/// not `CODEC_TYPE_NULL`, together with that track's sample rate (`0` when the
/// container does not report one).
///
/// # Errors
/// Returns an error when the file cannot be opened or probed, when it has no
/// decodable track, when no decoder exists for the track's codec, or when a
/// packet fails to decode. The two "no track" / "no decoder" cases used to
/// panic; since they depend on the user-supplied file they are now reported
/// through the `Result` like every other failure.
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;

    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .ok_or_else(|| anyhow::anyhow!("no supported audio tracks"))?;
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .map_err(|err| anyhow::anyhow!("unsupported codec: {err}"))?;
    // Copy out what is needed from `track` so the borrow of `format` ends
    // before packets are pulled from it below.
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);

    let mut pcm_data = Vec::new();
    // The loop stops at the first error from `next_packet`; no distinction is
    // made here between end of stream and any other read failure.
    while let Ok(packet) = format.next_packet() {
        // Discard queued metadata revisions until only the latest remains.
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        // Packets belonging to other tracks are skipped.
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}
// The operation selected on the command line; `main` matches on it to decide how the input
// file is read and what is written to the output file.
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
// Encode the input audio file to codes, then decode the codes and write a wav file.
AudioToAudio,
// Encode the input audio file and save the resulting codes to a safetensors file.
AudioToCode,
// Load codes from a safetensors file, decode them and write a wav file.
CodeToAudio,
}
// Command-line arguments of the encodec example. `action`, `in_file` and `out_file` carry no
// `#[arg(long)]` attribute, unlike `cpu` and `model`.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// The action to be performed, specifies the format for the input and output data.
action: Action,
/// The input file, either an audio file or some encodec tokens stored as safetensors.
in_file: String,
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
out_file: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// The model weight file, in safetensor format.
// When unset, `main` fetches `model.safetensors` from the `facebook/encodec_24khz` hub repo.
#[arg(long)]
model: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let config = Config::default();
let model = Model::new(&config, vb)?;
let codes = match args.action {
Action::CodeToAudio => {
let codes = candle::safetensors::load(args.in_file, &device)?;
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
codes
}
Action::AudioToCode | Action::AudioToAudio => {
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
if sample_rate != 24_000 {
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
}
let pcm_len = pcm.len();
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
println!("input pcm shape: {:?}", pcm.shape());
model.encode(&pcm)?
}
};
println!("codes shape: {:?}", codes.shape());
match args.action {
Action::AudioToCode => {
codes.save_safetensors("codes", &args.out_file)?;
}
Action::AudioToAudio | Action::CodeToAudio => {
let pcm = model.decode(&codes)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
}
}
Ok(())
}

View File

@ -1,27 +0,0 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google DeepMind with a 2b and a 7b variant.
In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).
## Running the example
```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
let mut primes = vec![true; max_n];
for i in 2..=max_n {
if primes[i] {
for j in i * i..max_n {
primes[j] = false;
}
}
}
primes.len()
}
```

View File

@ -1,256 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::gemma::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
/// State needed to run text generation: the model plus its tokenizer and sampling settings.
struct TextGeneration {
/// The model whose `forward` produces the logits for the next token.
model: Model,
/// Device on which the input token tensors are created.
device: Device,
/// Tokenizer wrapped in a `TokenOutputStream`, used to encode the prompt and to print
/// generated tokens incrementally.
tokenizer: TokenOutputStream,
/// Sampler that picks the next token from the logits.
logits_processor: LogitsProcessor,
/// Repeat penalty applied to the logits; a value of `1.` skips the penalty step.
repeat_penalty: f32,
/// Number of most recent tokens considered when applying the repeat penalty.
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<eos>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <eos> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
// Command-line arguments of the gemma example.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
// The text passed to `TextGeneration::run` as the prompt.
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
// Either one of the shorthands `7b-it`, `7b`, `2b-it`, `2b`, or a full hub repo id; `main`
// uses `google/gemma-2b` when unset.
#[arg(long)]
model_id: Option<String>,
// Revision of the hub repo to use.
#[arg(long, default_value = "main")]
revision: String,
// Local tokenizer file; when unset `main` fetches `tokenizer.json` from the repo.
#[arg(long)]
tokenizer_file: Option<String>,
// Local config file; when unset `main` fetches `config.json` from the repo.
#[arg(long)]
config_file: Option<String>,
// Comma-separated list of local weight files; when unset `main` resolves them through
// `model.safetensors.index.json` in the repo.
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
/// Entry point of the gemma example.
///
/// Parses the arguments, resolves the tokenizer, config and weight files (local paths or hub
/// downloads), builds the model and runs text generation on the prompt.
fn main() -> Result<()> {
    use std::path::PathBuf;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    // The guard is kept in a binding so that it lives until `main` returns.
    let _guard = args.tracing.then(|| {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        guard
    });
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    // Expand the short model names; anything else is taken as a full repo id.
    let model_id = match args.model_id.as_deref() {
        Some("7b-it") => "google/gemma-7b-it".to_string(),
        Some("7b") => "google/gemma-7b".to_string(),
        Some("2b-it") => "google/gemma-2b-it".to_string(),
        Some("2b") | None => "google/gemma-2b".to_string(),
        Some(other) => other.to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = if let Some(file) = args.tokenizer_file {
        PathBuf::from(file)
    } else {
        repo.get("tokenizer.json")?
    };
    let config_filename = if let Some(file) = args.config_file {
        PathBuf::from(file)
    } else {
        repo.get("config.json")?
    };
    let filenames = if let Some(files) = args.weight_files {
        files.split(',').map(PathBuf::from).collect::<Vec<_>>()
    } else {
        candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
    };
    println!("retrieved the files in {:?}", start.elapsed());

    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let config_file = std::fs::File::open(config_filename)?;
    let config: Config = serde_json::from_reader(config_file)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    // BF16 on CUDA devices, F32 everywhere else.
    let dtype = match device.is_cuda() {
        true => DType::BF16,
        false => DType::F32,
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)
}

View File

@ -57,7 +57,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)] #[arg(long, default_value_t = 100)]
sample_len: usize, sample_len: usize,
/// Disable the key-value cache. /// Disable the key-value cache.
@ -120,7 +120,7 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"), Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16, None => DType::F16,
}; };
let (llama, tokenizer_filename, mut cache) = { let (llama, tokenizer_filename, cache) = {
let api = Api::new()?; let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which { let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(), Which::V1 => "Narsil/amall-7b".to_string(),
@ -143,10 +143,11 @@ fn main() -> Result<()> {
} }
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?], Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
}; };
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache) (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}; };
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN); let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -156,7 +157,6 @@ fn main() -> Result<()> {
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
println!("starting the inference loop"); println!("starting the inference loop");
print!("{prompt}"); print!("{prompt}");
@ -172,7 +172,7 @@ fn main() -> Result<()> {
}; };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?; let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. { let logits = if args.repeat_penalty == 1. {
logits logits
@ -190,16 +190,18 @@ fn main() -> Result<()> {
token_generated += 1; token_generated += 1;
tokens.push(next_token); tokens.push(next_token);
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id { if Some(next_token) == eos_token_id {
break; break;
} }
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
} }
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(

View File

@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write; use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use model::{Cache, Config, Llama}; use model::{Config, Llama};
use qmodel::QLlama; use qmodel::QLlama;
use weights::TransformerWeights; use weights::TransformerWeights;
@ -160,10 +160,10 @@ enum Model {
} }
impl Model { impl Model {
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> { fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self { match self {
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?), Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?), Self::QLlama(l) => Ok(l.forward(xs, pos)?),
} }
} }
} }
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?; let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?; let model = Llama::load(vb, &cache, config)?;
let tokens = match &args.pretokenized_dir { let tokens = match &args.pretokenized_dir {
None => { None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter { for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, &mut cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?); println!("{}", loss.to_vec0::<f32>()?);
} }
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path let is_safetensors = config_path
.extension() .extension()
.map_or(false, |v| v == "safetensors"); .map_or(false, |v| v == "safetensors");
let (model, config, mut cache) = if is_gguf { let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?; let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")? .get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device, &device,
); );
let cache = model::Cache::new(true, &config, fake_vb)?; let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, config.clone())?); let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
} else if is_safetensors { } else if is_safetensors {
let config = Config::tiny_15m(); let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?; let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?); let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
} else { } else {
let mut file = std::fs::File::open(config_path)?; let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?; let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?; let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?; let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?; let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?); let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config, cache) (model, config)
}; };
println!("starting the inference loop"); println!("starting the inference loop");
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let start_gen = std::time::Instant::now(); let start_gen = std::time::Instant::now();
for index in 0.. { for index in 0.. {
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() }; let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?; let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos, &mut cache)?; let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?; let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() { let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits logits
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let next_token = logits_processor.sample(&logits)?; let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? { // Extracting the last token as a string is complicated, here we just apply some simple
print!("{t}"); // heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?; std::io::stdout().flush()?;
} }
} }
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed(); let dt = start_gen.elapsed();
println!( println!(
"\n{} tokens generated ({:.2} token/s)\n", "\n{} tokens generated ({:.2} token/s)\n",

View File

@ -8,7 +8,6 @@ fn valid_loss(
model: &Llama, model: &Llama,
args: &crate::TrainingCmd, args: &crate::TrainingCmd,
device: &Device, device: &Device,
cache: &mut Cache,
) -> Result<f64> { ) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone()); let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -16,7 +15,7 @@ fn valid_loss(
let mut cnt = 0usize; let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) { for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?; let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64; sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1; cnt += 1;
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
let mut cache = Cache::new(false, &config, vb.pp("rot"))?; let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?; let model = Llama::load(vb, &cache, config)?;
let params = candle_nn::ParamsAdamW { let params = candle_nn::ParamsAdamW {
lr: args.learning_rate, lr: args.learning_rate,
..Default::default() ..Default::default()
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?; let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() { for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?; let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0, &mut cache)?; let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?; let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?; opt.backward_step(&loss)?;
if batch_index > 0 && batch_index % 100 == 0 { if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the // TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss. // validation loss.
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?; let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}"); println!("{batch_index} {loss}");
} }
if batch_index > 0 && batch_index % 1000 == 0 { if batch_index > 0 && batch_index % 1000 == 0 {

View File

@ -1,18 +0,0 @@
# candle-metavoice
MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
details on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).
Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
probably because of some implementation discrepancies.
## Run an example
```bash
cargo run --example metavoice --release -- \
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -1,342 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::Parser;
use std::io::Write;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{
adapters, gpt, speaker_encoder, tokenizers, transformer,
};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};
pub const ENCODEC_NTOKENS: u32 = 1024;
/// Converts channel 0 of `data` to `f32` and appends the result to `samples`.
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;

    // Channels other than the first one are not read.
    let first_channel = data.chan(0);
    samples.extend(
        first_channel
            .iter()
            .copied()
            .map(<f32 as FromSample<T>>::from_sample),
    )
}
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
use symphonia::core::audio::{AudioBufferRef, Signal};
let src = std::fs::File::open(path)?;
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
let hint = symphonia::core::probe::Hint::new();
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
let mut format = probed.format;
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
.expect("no supported audio tracks");
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &Default::default())
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
while let Ok(packet) = format.next_packet() {
while !format.metadata().is_latest() {
format.metadata().pop();
}
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
// Command-line selector for the tensor dtype; `main` maps each variant to the matching
// `candle::DType` (`F32`, `F16`, `BF16`).
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
F32,
F16,
Bf16,
}
// Command-line arguments of the metavoice example.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
// The text to synthesize; `main` encodes it with the first-stage tokenizer.
#[arg(long)]
prompt: String,
/// The guidance scale.
#[arg(long, default_value_t = 3.0)]
guidance_scale: f64,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 1.0)]
temperature: f64,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The maximum number of tokens to generate for the first stage.
#[arg(long, default_value_t = 2000)]
max_tokens: u64,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
// Local metadata json file; when unset `main` fetches `first_stage.meta.json` from the
// `lmz/candle-metavoice` hub repo.
#[arg(long)]
first_stage_meta: Option<String>,
// Local first-stage weights; when unset `main` fetches `first_stage.safetensors`.
#[arg(long)]
first_stage_weights: Option<String>,
// Local second-stage weights; when unset `main` fetches `second_stage.safetensors`.
#[arg(long)]
second_stage_weights: Option<String>,
// Local speaker-encoder weights; when unset `main` fetches `speaker_encoder.safetensors`.
// Only read when `spk_emb` is not a safetensors file.
#[arg(long)]
speaker_encoder_weights: Option<String>,
// Local EnCodec weights; when unset `main` fetches `model.safetensors` from the
// `facebook/encodec_24khz` hub repo.
#[arg(long)]
encodec_weights: Option<String>,
// NOTE(review): "an audio files" in the help text below should read "an audio file"; it is
// left untouched here because it is user-visible output.
// When unset, `main` fetches `spk_emb.safetensors` from the hub repo.
/// The speaker embeddings, either an audio files in which case they are extracted, or a
/// safetensors file with the embeddings already extracted.
#[arg(long)]
spk_emb: Option<String>,
// Dtype used when loading the model weights.
#[arg(long, default_value = "f32")]
dtype: ArgDType,
}
fn mel_filters() -> Result<Vec<f32>> {
let mel_bytes = include_bytes!("melfilters40.bytes").as_slice();
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
Ok(mel_filters)
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let device = candle_examples::device(args.cpu)?;
let api = Api::new()?;
let repo = api.model("lmz/candle-metavoice".to_string());
let first_stage_meta = match &args.first_stage_meta {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.meta.json")?,
};
let first_stage_meta: serde_json::Value =
serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
let first_stage_tokenizer = match first_stage_meta.as_object() {
None => anyhow::bail!("not a json object"),
Some(j) => match j.get("tokenizer") {
None => anyhow::bail!("no tokenizer key"),
Some(j) => j,
},
};
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
let first_stage_weights = match &args.first_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("first_stage.safetensors")?,
};
let second_stage_weights = match &args.second_stage_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("second_stage.safetensors")?,
};
let encodec_weights = match args.encodec_weights {
Some(w) => std::path::PathBuf::from(w),
None => Api::new()?
.model("facebook/encodec_24khz".to_string())
.get("model.safetensors")?,
};
let dtype = match args.dtype {
ArgDType::F32 => DType::F32,
ArgDType::F16 => DType::F16,
ArgDType::Bf16 => DType::BF16,
};
let first_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
let first_stage_config = transformer::Config::cfg1b_v0_1();
let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
let second_stage_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
let second_stage_config = gpt::Config::cfg1b_v0_1();
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
let encodec_device = if device.is_metal() {
&candle::Device::Cpu
} else {
&device
};
let encodec_vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
let encodec_config = encodec::Config::default();
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
println!("prompt: '{}'", args.prompt);
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
let mut tokens = prompt_tokens.clone();
println!("{tokens:?}");
let safetensors_embeddings = args
.spk_emb
.as_ref()
.map_or(true, |v| v.ends_with("safetensors"));
let spk_emb = if safetensors_embeddings {
let spk_emb_file = match &args.spk_emb {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("spk_emb.safetensors")?,
};
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
match spk_emb.get("spk_emb") {
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
}
} else {
let weights = match &args.speaker_encoder_weights {
Some(w) => std::path::PathBuf::from(w),
None => repo.get("speaker_encoder.safetensors")?,
};
let mel_filters = mel_filters()?;
let config = speaker_encoder::Config::cfg();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
let model = speaker_encoder::Model::new(config, vb)?;
let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
if sample_rate != 16_000 {
eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
}
model.embed_utterance(
&pcm,
&mel_filters,
/* rate */ 1.3,
/* min_c */ 0.75,
&device,
)?
};
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
// First stage generation.
for index in 0..args.max_tokens {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &device)?;
let input = Tensor::stack(&[&input, &input], 0)?;
let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
let logits0 = logits.i((0, 0))?;
let logits1 = logits.i((1, 0))?;
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
let logits = logits.to_dtype(DType::F32)?;
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
print!(".");
std::io::stdout().flush()?;
if next_token == 2048 {
break;
}
}
println!();
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
println!("text ids len: {}", text_ids.len());
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
// TODO: Use the config rather than hardcoding the offset here.
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
let mut hierarchies_in1 =
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
let mut hierarchies_in2 = [
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
ids2.as_slice(),
&[ENCODEC_NTOKENS],
]
.concat();
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
let logits = second_stage_model.forward(&in_x)?;
println!("sampling from logits...");
let mut codes = vec![];
for logits in logits.iter() {
let logits = logits.squeeze(0)?;
let (seq_len, _) = logits.dims2()?;
let mut codes_ = Vec::with_capacity(seq_len);
for step in 0..seq_len {
let logits = logits.i(step)?.to_dtype(DType::F32)?;
let logits = &(&logits / 1.0)?;
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
let sample = distr.sample(&mut rng) as u32;
codes_.push(sample)
}
codes.push(codes_)
}
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
let codes = Tensor::cat(&[in_x, codes], 1)?;
println!("codes: {codes}");
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
let codes = codes.i(0)?.to_vec2::<u32>()?;
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
println!("text_ids len: {:?}", text_ids.len());
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
println!("audio_ids shape: {:?}", audio_ids.shape());
let pcm = encodec_model.decode(&audio_ids)?;
println!("output pcm shape: {:?}", pcm.shape());
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
let pcm = pcm.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

View File

@ -152,7 +152,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long)] #[arg(long)]

View File

@ -143,7 +143,7 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")] #[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]

View File

@ -0,0 +1,580 @@
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Module, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
/// Normalization scheme used by the Encodec convolution layers.
#[derive(Debug, Clone, PartialEq)]
enum NormType {
/// The convolution weight is rebuilt from `weight_g` / `weight_v` at load time.
WeightNorm,
/// A plain convolution followed by a single-group `GroupNorm`.
TimeGroupNorm,
/// A plain convolution without any normalization.
None,
}
/// Hyper-parameters of the Encodec audio codec.
///
/// NOTE(review): the field names appear to mirror the `EncodecConfig` of the
/// Python transformers implementation linked above; several fields are stored
/// but not read by the code in this file yet.
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
// Only the last entry is used here, to derive the number of quantizers.
target_bandwidths: Vec<f64>,
// Used together with `upsampling_ratios` to compute the frame rate.
sampling_rate: usize,
// Input channels of the first encoder conv, output channels of the last decoder conv.
audio_channels: usize,
normalize: bool,
chunk_length_s: Option<usize>,
overlap: Option<usize>,
// Channel count at the encoder output / decoder input.
hidden_size: usize,
// Base channel count; it is doubled after each encoder down-sampling stage.
num_filters: usize,
// Number of resnet blocks per sampling stage.
num_residual_layers: usize,
// Per-stage strides; their product is the hop length used by `frame_rate`.
upsampling_ratios: Vec<usize>,
norm_type: NormType,
kernel_size: usize,
last_kernel_size: usize,
residual_kernel_size: usize,
dilation_growth_rate: usize,
use_causal_conv: bool,
pad_mode: &'static str,
// Divisor applied to the channel count inside a resnet block.
compress: usize,
num_lstm_layers: usize,
trim_right_ratio: f64,
// Number of entries in each codebook.
codebook_size: usize,
// Optional codebook vector size; `Config::codebook_dim()` supplies the fallback.
codebook_dim: Option<usize>,
// When false, resnet blocks use an identity skip connection.
use_conv_shortcut: bool,
}
/// Default hyper-parameters: 24 kHz mono audio, causal weight-normalized
/// convolutions and 1024-entry codebooks.
///
/// NOTE(review): presumably matches the defaults of the reference
/// `EncodecConfig` — confirm against the upstream implementation.
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
pad_mode: "reflect",
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
// Left unset: the effective value comes from `Config::codebook_dim()`.
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
    /// Encodec configuration shipped with the `facebook/musicgen-small` checkpoint.
    // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
    pub fn musicgen_small() -> Self {
        Self {
            audio_channels: 1,
            chunk_length_s: None,
            codebook_dim: Some(128),
            codebook_size: 2048,
            compress: 2,
            dilation_growth_rate: 2,
            hidden_size: 128,
            kernel_size: 7,
            last_kernel_size: 7,
            norm_type: NormType::WeightNorm,
            normalize: false,
            num_filters: 64,
            num_lstm_layers: 2,
            num_residual_layers: 1,
            overlap: None,
            pad_mode: "reflect",
            residual_kernel_size: 3,
            sampling_rate: 32_000,
            target_bandwidths: vec![2.2],
            trim_right_ratio: 1.0,
            upsampling_ratios: vec![8, 5, 4, 4],
            use_causal_conv: false,
            use_conv_shortcut: false,
        }
    }

    /// Size of each codebook vector.
    ///
    /// When `codebook_dim` is not set explicitly the reference implementation
    /// falls back to `hidden_size` (the width of the latent fed to the
    /// quantizer), not to the number of codebook entries.
    fn codebook_dim(&self) -> usize {
        self.codebook_dim.unwrap_or(self.hidden_size)
    }

    /// Number of latent frames per second of audio, i.e.
    /// `ceil(sampling_rate / hop_length)` where the hop length is the product
    /// of all the up-sampling ratios.
    fn frame_rate(&self) -> usize {
        let hop_length: usize = self.upsampling_ratios.iter().product();
        // Integer ceiling division.
        (self.sampling_rate + hop_length - 1) / hop_length
    }

    /// Number of residual quantizer layers needed to reach the largest
    /// (last) target bandwidth.
    ///
    /// # Panics
    ///
    /// Panics if `target_bandwidths` is empty.
    fn num_quantizers(&self) -> usize {
        let num = 1000f64
            * self
                .target_bandwidths
                .last()
                .expect("empty target_bandwidths");
        (num as usize) / (self.frame_rate() * 10)
    }
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
/// Codebook of a single vector-quantization layer.
///
/// Only `embed` is read when decoding; the remaining tensors are loaded so
/// that the expected checkpoint entries are checked for presence and shape.
#[derive(Debug)]
struct EncodecEuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    embed: Tensor,
    embed_avg: Tensor,
}

impl EncodecEuclideanCodebook {
    /// Reads the codebook tensors found under `vb`.
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let (n_entries, entry_dim) = (cfg.codebook_size, cfg.codebook_dim());
        Ok(Self {
            inited: vb.get(1, "inited")?,
            cluster_size: vb.get(n_entries, "cluster_size")?,
            embed: vb.get((n_entries, entry_dim), "embed")?,
            embed_avg: vb.get((n_entries, entry_dim), "embed_avg")?,
        })
    }

    /// Looks up the codebook vectors selected by the indices in `embed_ind`.
    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        self.embed.embedding(embed_ind)
    }
}
/// One quantization layer: a thin wrapper around its codebook.
#[derive(Debug)]
struct EncodecVectorQuantization {
    codebook: EncodecEuclideanCodebook,
}

impl EncodecVectorQuantization {
    /// Loads the layer, reading the codebook from the `codebook` sub-path of `vb`.
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg).map(|codebook| Self { codebook })
    }

    /// Decodes indices into codebook vectors and swaps dimensions 1 and 2 of
    /// the result.
    fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        self.codebook.decode(embed_ind)?.transpose(1, 2)
    }
}
/// Stack of quantization layers whose decoded outputs are summed.
#[derive(Debug)]
struct EncodecResidualVectorQuantizer {
    layers: Vec<EncodecVectorQuantization>,
}

impl EncodecResidualVectorQuantizer {
    /// Loads `cfg.num_quantizers()` layers from `layers.0`, `layers.1`, ...
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let layers_vb = vb.pp("layers");
        let n_layers = cfg.num_quantizers();
        let mut layers = Vec::with_capacity(n_layers);
        for idx in 0..n_layers {
            let layer = EncodecVectorQuantization::load(layers_vb.pp(&idx.to_string()), cfg)?;
            layers.push(layer);
        }
        Ok(Self { layers })
    }

    /// Sums the decoded output of every layer.
    ///
    /// The first dimension of `codes` indexes the quantization layer and must
    /// match the number of loaded layers, otherwise an error is returned.
    fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        // Scalar zero used as the starting value of the broadcasted sum.
        let zero = Tensor::zeros((), DType::F32, codes.device())?;
        let n_codes = codes.dim(0)?;
        if n_codes != self.layers.len() {
            candle::bail!(
                "codes shape {:?} does not match the number of quantization layers {}",
                codes.shape(),
                self.layers.len()
            )
        }
        self.layers
            .iter()
            .enumerate()
            .try_fold(zero, |acc, (idx, layer)| {
                layer.decode(&codes.i(idx)?)?.broadcast_add(&acc)
            })
    }
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
/// Stack of LSTM layers sharing the same input and hidden size.
#[derive(Debug)]
struct EncodecLSTM {
    layers: Vec<candle_nn::LSTM>,
}

impl EncodecLSTM {
    /// Loads `cfg.num_lstm_layers` layers of width `dim` from the `lstm`
    /// sub-path of `vb`; each layer is told apart by its `layer_idx`.
    fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let lstm_vb = vb.pp("lstm");
        let layers = (0..cfg.num_lstm_layers)
            .map(|layer_idx| {
                let config = candle_nn::LSTMConfig {
                    layer_idx,
                    ..Default::default()
                };
                candle_nn::lstm(dim, dim, config, lstm_vb.clone())
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }
}
impl Module for EncodecLSTM {
    /// Runs the LSTM stack over the time axis and adds a skip connection.
    ///
    /// NOTE(review): assumes `xs` uses the channel-first layout
    /// `(batch, channels, time)` produced by the surrounding convolutions —
    /// confirm against the callers.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;
        // The candle LSTM is batch-first and iterates over dimension 1, so the
        // time axis has to be moved there before running the layers.
        let xs = xs.t()?.contiguous()?;
        let mut hidden = xs.clone();
        for layer in self.layers.iter() {
            let states = layer.seq(&hidden)?;
            hidden = layer.states_to_tensor(&states)?;
        }
        // Skip connection around the whole stack, as in the reference
        // implementation, then back to the channel-first layout.
        (hidden + xs)?.t()
    }
}
/// Weight-normalized transposed 1d convolution used by the decoder.
///
/// Only the raw weight-norm tensors are stored for now; the forward pass is
/// not implemented yet.
#[derive(Debug)]
struct EncodecConvTranspose1d {
weight_g: Tensor,
weight_v: Tensor,
bias: Tensor,
}
impl EncodecConvTranspose1d {
/// Loads `weight_g`, `weight_v` and `bias` from the `conv` sub-path of `vb`.
///
/// `_stride` and `_cfg` are accepted but currently unused.
fn load(
in_c: usize,
out_c: usize,
k: usize,
_stride: usize,
vb: VarBuilder,
_cfg: &Config,
) -> Result<Self> {
let vb = &vb.pp("conv");
// The leading dimension of both weight tensors is the input channel count.
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
let bias = vb.get(out_c, "bias")?;
Ok(Self {
weight_g,
weight_v,
bias,
})
}
}
impl Module for EncodecConvTranspose1d {
/// Not implemented yet: calling this panics via `todo!()`.
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}
/// 1d convolution with the normalization selected by `Config::norm_type`.
#[derive(Debug)]
struct EncodecConv1d {
    causal: bool,
    conv: Conv1d,
    norm: Option<candle_nn::GroupNorm>,
}

impl EncodecConv1d {
    /// Loads the convolution from the `conv` sub-path of `vb` and, for
    /// `NormType::TimeGroupNorm`, a single-group norm from `norm`.
    fn load(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        stride: usize,
        vb: VarBuilder,
        cfg: &Config,
    ) -> Result<Self> {
        // Every variant shares the same convolution settings and weight path.
        let conv_cfg = Conv1dConfig {
            padding: 0,
            stride,
            groups: 1,
            dilation: 1,
        };
        let conv_vb = vb.pp("conv");
        let (conv, norm) = match cfg.norm_type {
            NormType::WeightNorm => {
                let conv = conv1d_weight_norm(in_c, out_c, kernel_size, conv_cfg, conv_vb)?;
                (conv, None)
            }
            NormType::None => {
                let conv = conv1d(in_c, out_c, kernel_size, conv_cfg, conv_vb)?;
                (conv, None)
            }
            NormType::TimeGroupNorm => {
                let conv = conv1d(in_c, out_c, kernel_size, conv_cfg, conv_vb)?;
                let group_norm = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
                (conv, Some(group_norm))
            }
        };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
            norm,
        })
    }
}

impl Module for EncodecConv1d {
    /// Applies the convolution, then the group norm when one was loaded.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // TODO: padding, depending on causal.
        let ys = self.conv.forward(xs)?;
        if let Some(norm) = &self.norm {
            ys.apply(norm)
        } else {
            Ok(ys)
        }
    }
}
/// Residual block: two convolutions, each preceded by an ELU, plus a skip
/// connection that is either the identity or a 1x1 convolution.
#[derive(Debug)]
struct EncodecResnetBlock {
    block_conv1: EncodecConv1d,
    block_conv2: EncodecConv1d,
    shortcut: Option<EncodecConv1d>,
}

impl EncodecResnetBlock {
    /// Loads the block; `dilations` must hold exactly two values.
    fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
        if dilations.len() != 2 {
            candle::bail!("expected dilations of size 2")
        }
        // TODO: Apply dilations!
        let hidden = dim / cfg.compress;
        // The convolutions sit at indices 1 and 3 of `block`; the skipped
        // entries in between carry no weights.
        let mut block = Layer::new(vb.pp("block"));
        block.inc();
        let block_conv1 = EncodecConv1d::load(
            dim,
            hidden,
            cfg.residual_kernel_size,
            1,
            block.next(),
            cfg,
        )?;
        block.inc();
        let block_conv2 = EncodecConv1d::load(hidden, dim, 1, 1, block.next(), cfg)?;
        let shortcut = cfg
            .use_conv_shortcut
            .then(|| EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg))
            .transpose()?;
        Ok(Self {
            block_conv1,
            block_conv2,
            shortcut,
        })
    }
}

impl Module for EncodecResnetBlock {
    /// Computes `conv2(elu(conv1(elu(xs))))` and adds the skip connection.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let ys = xs
            .elu(1.)?
            .apply(&self.block_conv1)?
            .elu(1.)?
            .apply(&self.block_conv2)?;
        match &self.shortcut {
            Some(shortcut) => ys.add(&shortcut.forward(xs)?),
            None => ys + xs,
        }
    }
}
/// Cursor over the numbered children (`0`, `1`, ...) of a `VarBuilder` path.
struct Layer<'a> {
    vb: VarBuilder<'a>,
    cnt: usize,
}

impl<'a> Layer<'a> {
    /// Starts at child index 0.
    fn new(vb: VarBuilder<'a>) -> Self {
        Self { vb, cnt: 0 }
    }

    /// Skips the current index without creating a builder for it.
    fn inc(&mut self) {
        self.cnt += 1;
    }

    /// Returns the builder for the current index and moves on to the next one.
    fn next(&mut self) -> VarBuilder {
        let idx = self.cnt;
        self.inc();
        self.vb.pp(&idx.to_string())
    }
}
/// Encoder: an initial convolution, one down-sampling stage per entry of
/// `upsampling_ratios` (in reverse order), an LSTM stack and a final convolution.
#[derive(Debug)]
struct EncodecEncoder {
init_conv: EncodecConv1d,
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
final_lstm: EncodecLSTM,
final_conv: EncodecConv1d,
}
impl EncodecEncoder {
/// Loads the encoder from the `layers` sub-path of `vb`.
///
/// The order of the `layer.next()` / `layer.inc()` calls determines the
/// numeric index each sub-module is read from, so it must not be changed.
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
let init_conv = EncodecConv1d::load(
cfg.audio_channels,
cfg.num_filters,
cfg.kernel_size,
1,
layer.next(),
cfg,
)?;
let mut sampling_layers = vec![];
// Channel multiplier, doubled after every down-sampling stage.
let mut scaling = 1;
for &ratio in cfg.upsampling_ratios.iter().rev() {
let current_scale = scaling * cfg.num_filters;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
layer.inc(); // ELU
// Strided convolution: kernel is twice the stride, channels are doubled.
let conv1d = EncodecConv1d::load(
current_scale,
current_scale * 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
sampling_layers.push((resnets, conv1d));
scaling *= 2;
}
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters * scaling,
cfg.hidden_size,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
sampling_layers,
final_conv,
final_lstm,
})
}
/// Applies the layers in the order they were loaded, with an ELU before
/// each strided convolution and before the final convolution.
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?;
for (resnets, conv) in self.sampling_layers.iter() {
for resnet in resnets.iter() {
xs = xs.apply(resnet)?;
}
xs = xs.elu(1.0)?.apply(conv)?;
}
xs.apply(&self.final_lstm)?
.elu(1.0)?
.apply(&self.final_conv)
}
}
/// Decoder: an initial convolution and LSTM stack, one up-sampling stage per
/// entry of `upsampling_ratios`, and a final convolution.
#[derive(Debug)]
struct EncodecDecoder {
init_conv: EncodecConv1d,
init_lstm: EncodecLSTM,
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
final_conv: EncodecConv1d,
}
impl EncodecDecoder {
/// Loads the decoder from the `layers` sub-path of `vb`.
///
/// The order of the `layer.next()` / `layer.inc()` calls determines the
/// numeric index each sub-module is read from, so it must not be changed.
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let mut layer = Layer::new(vb.pp("layers"));
// Channel multiplier: starts at 2^(number of stages) and is halved per stage.
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
let init_conv = EncodecConv1d::load(
cfg.hidden_size,
cfg.num_filters * scaling,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
let mut sampling_layers = vec![];
for &ratio in cfg.upsampling_ratios.iter() {
let current_scale = scaling * cfg.num_filters;
layer.inc(); // ELU
// Transposed convolution: kernel is twice the stride, channels are halved.
let conv1d = EncodecConvTranspose1d::load(
current_scale,
current_scale / 2,
ratio * 2,
ratio,
layer.next(),
cfg,
)?;
let mut resnets = vec![];
for j in 0..(cfg.num_residual_layers as u32) {
let resnet = EncodecResnetBlock::load(
current_scale / 2,
&[cfg.dilation_growth_rate.pow(j), 1],
layer.next(),
cfg,
)?;
resnets.push(resnet)
}
sampling_layers.push((conv1d, resnets));
scaling /= 2;
}
layer.inc(); // ELU
let final_conv = EncodecConv1d::load(
cfg.num_filters,
cfg.audio_channels,
cfg.last_kernel_size,
1,
layer.next(),
cfg,
)?;
Ok(Self {
init_conv,
init_lstm,
sampling_layers,
final_conv,
})
}
/// Applies the layers in the order they were loaded, with an ELU before
/// each transposed convolution and before the final convolution.
///
/// NOTE(review): `EncodecConvTranspose1d::forward` is still `todo!()`, so
/// this panics as soon as there is at least one sampling stage.
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
for (conv, resnets) in self.sampling_layers.iter() {
xs = xs.elu(1.)?.apply(conv)?;
for resnet in resnets.iter() {
xs = xs.apply(resnet)?
}
}
xs.elu(1.)?.apply(&self.final_conv)
}
}
/// Full Encodec model: encoder, decoder and residual vector quantizer.
#[derive(Debug)]
pub struct EncodecModel {
encoder: EncodecEncoder,
decoder: EncodecDecoder,
quantizer: EncodecResidualVectorQuantizer,
}
impl EncodecModel {
/// Loads the three sub-modules from the `encoder`, `decoder` and
/// `quantizer` sub-paths of `vb`.
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
Ok(Self {
encoder,
decoder,
quantizer,
})
}
/// Not implemented yet: calling this panics via `todo!()`.
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
todo!()
}
}

View File

@ -10,7 +10,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
mod encodec_model;
mod musicgen_model; mod musicgen_model;
mod nn;
use musicgen_model::{GenConfig, MusicgenForConditionalGeneration}; use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

View File

@ -1,9 +1,10 @@
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D}; use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{ use candle_nn::{
embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module, embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
VarBuilder, VarBuilder,
}; };
use candle_transformers::models::{encodec, t5}; use candle_transformers::models::t5;
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83 // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)] #[derive(Debug)]
pub struct MusicgenForConditionalGeneration { pub struct MusicgenForConditionalGeneration {
pub text_encoder: t5::T5EncoderModel, pub text_encoder: t5::T5EncoderModel,
pub audio_encoder: encodec::Model, pub audio_encoder: crate::encodec_model::EncodecModel,
pub decoder: MusicgenForCausalLM, pub decoder: MusicgenForCausalLM,
cfg: GenConfig, cfg: GenConfig,
} }
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
pub struct GenConfig { pub struct GenConfig {
musicgen: Config, musicgen: Config,
t5: t5::Config, t5: t5::Config,
encodec: encodec::Config, encodec: crate::encodec_model::Config,
} }
impl GenConfig { impl GenConfig {
pub fn small() -> Self { pub fn small() -> Self {
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
let encodec = encodec::Config {
audio_channels: 1,
chunk_length_s: None,
codebook_dim: Some(128),
codebook_size: 2048,
compress: 2,
dilation_growth_rate: 2,
hidden_size: 128,
kernel_size: 7,
last_kernel_size: 7,
norm_type: encodec::NormType::WeightNorm,
normalize: false,
num_filters: 64,
num_lstm_layers: 2,
num_residual_layers: 1,
overlap: None,
// This should be Reflect and not Replicate but Reflect does not work yet.
pad_mode: encodec::PadMode::Replicate,
residual_kernel_size: 3,
sampling_rate: 32_000,
target_bandwidths: vec![2.2],
trim_right_ratio: 1.0,
upsampling_ratios: vec![8, 5, 4, 4],
use_causal_conv: false,
use_conv_shortcut: false,
};
Self { Self {
musicgen: Config::musicgen_small(), musicgen: Config::musicgen_small(),
t5: t5::Config::musicgen_small(), t5: t5::Config::musicgen_small(),
encodec, encodec: encodec_model::Config::musicgen_small(),
} }
} }
} }
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> { pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?; let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?; let audio_encoder =
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?; let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
Ok(Self { Ok(Self {
text_encoder, text_encoder,

View File

@ -0,0 +1,20 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
/// Builds a `Conv1d` from weight-normalized parameters.
///
/// The effective weight is recomputed once from the stored direction
/// (`weight_v`) and magnitude (`weight_g`) tensors, so this is only suitable
/// for inference and does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    config: Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let magnitude = vb.get((out_c, 1, 1), "weight_g")?;
    let direction = vb.get((out_c, in_c, kernel_size), "weight_v")?;
    // Per-output-channel L2 norm of the direction tensor.
    let direction_norm = direction.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    let weight = direction
        .broadcast_mul(&magnitude)?
        .broadcast_div(&direction_norm)?;
    let bias = vb.get(out_c, "bias")?;
    Ok(Conv1d::new(weight, Some(bias), config))
}

View File

@ -212,14 +212,6 @@ struct Args {
#[arg(long)] #[arg(long)]
verbose_prompt: bool, verbose_prompt: bool,
/// Process prompt elements separately.
#[arg(long)]
split_prompt: bool,
/// Run on CPU rather than GPU even if a GPU is available.
#[arg(long)]
cpu: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty. /// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)] #[arg(long, default_value_t = 1.1)]
repeat_penalty: f32, repeat_penalty: f32,
@ -369,7 +361,7 @@ fn main() -> anyhow::Result<()> {
let model_path = args.model()?; let model_path = args.model()?;
let mut file = std::fs::File::open(&model_path)?; let mut file = std::fs::File::open(&model_path)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(false)?;
let mut model = match model_path.extension().and_then(|v| v.to_str()) { let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => { Some("gguf") => {
@ -495,20 +487,11 @@ fn main() -> anyhow::Result<()> {
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p); let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now(); let start_prompt_processing = std::time::Instant::now();
let mut next_token = if !args.split_prompt { let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?; let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?; let logits = logits.squeeze(0)?;
logits_processor.sample(&logits)? logits_processor.sample(&logits)?
} else {
let mut next_token = 0;
for (pos, token) in prompt_tokens.iter().enumerate() {
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
let logits = model.forward(&input, pos)?;
let logits = logits.squeeze(0)?;
next_token = logits_processor.sample(&logits)?
}
next_token
}; };
let prompt_dt = start_prompt_processing.elapsed(); let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token); all_tokens.push(next_token);

View File

@ -2,8 +2,8 @@
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are with performance on par with transformer architectures. Several variants are
available, candle implements the v5 and v6 versions and can be used with available, candle implements the v5 version and can be used with Eagle 7B([blog
Eagle 7B([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)). post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
```bash ```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is " $ cargo run --example rwkv --release -- --prompt "The smallest prime is "

View File

@ -7,36 +7,13 @@ extern crate accelerate_src;
use anyhow::Result; use anyhow::Result;
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use candle_transformers::models::quantized_rwkv_v5::Model as Q5; use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
const EOS_TOKEN_ID: u32 = 261;
enum Model {
M5(M5),
Q5(Q5),
M6(M6),
Q6(Q6),
}
impl Model {
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
match self {
Self::M5(m) => m.forward(xs, state),
Self::Q5(m) => m.forward(xs, state),
Self::M6(m) => m.forward(xs, state),
Self::Q6(m) => m.forward(xs, state),
}
}
}
struct TextGeneration { struct TextGeneration {
model: Model, model: Model,
config: Config, config: Config,
@ -106,9 +83,6 @@ impl TextGeneration {
let next_token = self.logits_processor.sample(&logits)?; let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token); tokens.push(next_token);
generated_tokens += 1; generated_tokens += 1;
if next_token == EOS_TOKEN_ID || next_token == 0 {
break;
}
print!("{}", self.tokenizer.decode(&[next_token])?); print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?; std::io::stdout().flush()?;
@ -129,7 +103,6 @@ enum Which {
Eagle7b, Eagle7b,
World1b5, World1b5,
World3b, World3b,
World6_1b6,
} }
impl std::fmt::Display for Which { impl std::fmt::Display for Which {
@ -144,7 +117,6 @@ impl Which {
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B", Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
Self::World1b5 => "RWKV/rwkv-5-world-1b5", Self::World1b5 => "RWKV/rwkv-5-world-1b5",
Self::World3b => "RWKV/rwkv-5-world-3b", Self::World3b => "RWKV/rwkv-5-world-3b",
Self::World6_1b6 => "paperfun/rwkv",
} }
} }
@ -152,7 +124,6 @@ impl Which {
match self { match self {
Self::Eagle7b => "refs/pr/1", Self::Eagle7b => "refs/pr/1",
Self::World1b5 | Self::World3b => "refs/pr/2", Self::World1b5 | Self::World3b => "refs/pr/2",
Self::World6_1b6 => "main",
} }
} }
} }
@ -205,9 +176,6 @@ struct Args {
#[arg(long)] #[arg(long)]
config_file: Option<String>, config_file: Option<String>,
#[arg(long)]
quantized: bool,
/// Penalty to be applied for repeating tokens, 1. means no penalty. /// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)] #[arg(long, default_value_t = 1.1)]
repeat_penalty: f32, repeat_penalty: f32,
@ -268,27 +236,7 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from) .map(std::path::PathBuf::from)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
None => { None => {
if args.quantized { vec![repo.get("model.safetensors")?]
vec![match args.which {
Which::World1b5 => api
.model("lmz/candle-rwkv".to_string())
.get("world1b5-q4k.gguf")?,
Which::World3b => api
.model("lmz/candle-rwkv".to_string())
.get("world3b-q4k.gguf")?,
Which::Eagle7b => api
.model("lmz/candle-rwkv".to_string())
.get("eagle7b-q4k.gguf")?,
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
}]
} else {
vec![match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => {
repo.get("model.safetensors")?
}
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
}]
}
} }
}; };
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
@ -297,21 +245,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let model = if args.quantized { let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let filename = &filenames[0]; let model = Model::new(&config, vb)?;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
}
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
match args.which {
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
}
};
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new( let mut pipeline = TextGeneration::new(

View File

@ -1,28 +0,0 @@
# candle-segformer
- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]
## How to run the example
If you want, you can use the example images from this [pull request][pr]: download them and supply the path to the image as an argument to the example.
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```
Example output for classification:
```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

View File

@ -1,752 +0,0 @@
[
{
"index": 1,
"color": "#787878",
"label": "wall"
},
{
"index": 2,
"color": "#B47878",
"label": "building;edifice"
},
{
"index": 3,
"color": "#06E6E6",
"label": "sky"
},
{
"index": 4,
"color": "#503232",
"label": "floor;flooring"
},
{
"index": 5,
"color": "#04C803",
"label": "tree"
},
{
"index": 6,
"color": "#787850",
"label": "ceiling"
},
{
"index": 7,
"color": "#8C8C8C",
"label": "road;route"
},
{
"index": 8,
"color": "#CC05FF",
"label": "bed"
},
{
"index": 9,
"color": "#E6E6E6",
"label": "windowpane;window"
},
{
"index": 10,
"color": "#04FA07",
"label": "grass"
},
{
"index": 11,
"color": "#E005FF",
"label": "cabinet"
},
{
"index": 12,
"color": "#EBFF07",
"label": "sidewalk;pavement"
},
{
"index": 13,
"color": "#96053D",
"label": "person;individual;someone;somebody;mortal;soul"
},
{
"index": 14,
"color": "#787846",
"label": "earth;ground"
},
{
"index": 15,
"color": "#08FF33",
"label": "door;double;door"
},
{
"index": 16,
"color": "#FF0652",
"label": "table"
},
{
"index": 17,
"color": "#8FFF8C",
"label": "mountain;mount"
},
{
"index": 18,
"color": "#CCFF04",
"label": "plant;flora;plant;life"
},
{
"index": 19,
"color": "#FF3307",
"label": "curtain;drape;drapery;mantle;pall"
},
{
"index": 20,
"color": "#CC4603",
"label": "chair"
},
{
"index": 21,
"color": "#0066C8",
"label": "car;auto;automobile;machine;motorcar"
},
{
"index": 22,
"color": "#3DE6FA",
"label": "water"
},
{
"index": 23,
"color": "#FF0633",
"label": "painting;picture"
},
{
"index": 24,
"color": "#0B66FF",
"label": "sofa;couch;lounge"
},
{
"index": 25,
"color": "#FF0747",
"label": "shelf"
},
{
"index": 26,
"color": "#FF09E0",
"label": "house"
},
{
"index": 27,
"color": "#0907E6",
"label": "sea"
},
{
"index": 28,
"color": "#DCDCDC",
"label": "mirror"
},
{
"index": 29,
"color": "#FF095C",
"label": "rug;carpet;carpeting"
},
{
"index": 30,
"color": "#7009FF",
"label": "field"
},
{
"index": 31,
"color": "#08FFD6",
"label": "armchair"
},
{
"index": 32,
"color": "#07FFE0",
"label": "seat"
},
{
"index": 33,
"color": "#FFB806",
"label": "fence;fencing"
},
{
"index": 34,
"color": "#0AFF47",
"label": "desk"
},
{
"index": 35,
"color": "#FF290A",
"label": "rock;stone"
},
{
"index": 36,
"color": "#07FFFF",
"label": "wardrobe;closet;press"
},
{
"index": 37,
"color": "#E0FF08",
"label": "lamp"
},
{
"index": 38,
"color": "#6608FF",
"label": "bathtub;bathing;tub;bath;tub"
},
{
"index": 39,
"color": "#FF3D06",
"label": "railing;rail"
},
{
"index": 40,
"color": "#FFC207",
"label": "cushion"
},
{
"index": 41,
"color": "#FF7A08",
"label": "base;pedestal;stand"
},
{
"index": 42,
"color": "#00FF14",
"label": "box"
},
{
"index": 43,
"color": "#FF0829",
"label": "column;pillar"
},
{
"index": 44,
"color": "#FF0599",
"label": "signboard;sign"
},
{
"index": 45,
"color": "#0633FF",
"label": "chest;of;drawers;chest;bureau;dresser"
},
{
"index": 46,
"color": "#EB0CFF",
"label": "counter"
},
{
"index": 47,
"color": "#A09614",
"label": "sand"
},
{
"index": 48,
"color": "#00A3FF",
"label": "sink"
},
{
"index": 49,
"color": "#8C8C8C",
"label": "skyscraper"
},
{
"index": 50,
"color": "#FA0A0F",
"label": "fireplace;hearth;open;fireplace"
},
{
"index": 51,
"color": "#14FF00",
"label": "refrigerator;icebox"
},
{
"index": 52,
"color": "#1FFF00",
"label": "grandstand;covered;stand"
},
{
"index": 53,
"color": "#FF1F00",
"label": "path"
},
{
"index": 54,
"color": "#FFE000",
"label": "stairs;steps"
},
{
"index": 55,
"color": "#99FF00",
"label": "runway"
},
{
"index": 56,
"color": "#0000FF",
"label": "case;display;case;showcase;vitrine"
},
{
"index": 57,
"color": "#FF4700",
"label": "pool;table;billiard;table;snooker;table"
},
{
"index": 58,
"color": "#00EBFF",
"label": "pillow"
},
{
"index": 59,
"color": "#00ADFF",
"label": "screen;door;screen"
},
{
"index": 60,
"color": "#1F00FF",
"label": "stairway;staircase"
},
{
"index": 61,
"color": "#0BC8C8",
"label": "river"
},
{
"index": 62,
"color": "#FF5200",
"label": "bridge;span"
},
{
"index": 63,
"color": "#00FFF5",
"label": "bookcase"
},
{
"index": 64,
"color": "#003DFF",
"label": "blind;screen"
},
{
"index": 65,
"color": "#00FF70",
"label": "coffee;table;cocktail;table"
},
{
"index": 66,
"color": "#00FF85",
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
},
{
"index": 67,
"color": "#FF0000",
"label": "flower"
},
{
"index": 68,
"color": "#FFA300",
"label": "book"
},
{
"index": 69,
"color": "#FF6600",
"label": "hill"
},
{
"index": 70,
"color": "#C2FF00",
"label": "bench"
},
{
"index": 71,
"color": "#008FFF",
"label": "countertop"
},
{
"index": 72,
"color": "#33FF00",
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
},
{
"index": 73,
"color": "#0052FF",
"label": "palm;palm;tree"
},
{
"index": 74,
"color": "#00FF29",
"label": "kitchen;island"
},
{
"index": 75,
"color": "#00FFAD",
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
},
{
"index": 76,
"color": "#0A00FF",
"label": "swivel;chair"
},
{
"index": 77,
"color": "#ADFF00",
"label": "boat"
},
{
"index": 78,
"color": "#00FF99",
"label": "bar"
},
{
"index": 79,
"color": "#FF5C00",
"label": "arcade;machine"
},
{
"index": 80,
"color": "#FF00FF",
"label": "hovel;hut;hutch;shack;shanty"
},
{
"index": 81,
"color": "#FF00F5",
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
},
{
"index": 82,
"color": "#FF0066",
"label": "towel"
},
{
"index": 83,
"color": "#FFAD00",
"label": "light;light;source"
},
{
"index": 84,
"color": "#FF0014",
"label": "truck;motortruck"
},
{
"index": 85,
"color": "#FFB8B8",
"label": "tower"
},
{
"index": 86,
"color": "#001FFF",
"label": "chandelier;pendant;pendent"
},
{
"index": 87,
"color": "#00FF3D",
"label": "awning;sunshade;sunblind"
},
{
"index": 88,
"color": "#0047FF",
"label": "streetlight;street;lamp"
},
{
"index": 89,
"color": "#FF00CC",
"label": "booth;cubicle;stall;kiosk"
},
{
"index": 90,
"color": "#00FFC2",
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
},
{
"index": 91,
"color": "#00FF52",
"label": "airplane;aeroplane;plane"
},
{
"index": 92,
"color": "#000AFF",
"label": "dirt;track"
},
{
"index": 93,
"color": "#0070FF",
"label": "apparel;wearing;apparel;dress;clothes"
},
{
"index": 94,
"color": "#3300FF",
"label": "pole"
},
{
"index": 95,
"color": "#00C2FF",
"label": "land;ground;soil"
},
{
"index": 96,
"color": "#007AFF",
"label": "bannister;banister;balustrade;balusters;handrail"
},
{
"index": 97,
"color": "#00FFA3",
"label": "escalator;moving;staircase;moving;stairway"
},
{
"index": 98,
"color": "#FF9900",
"label": "ottoman;pouf;pouffe;puff;hassock"
},
{
"index": 99,
"color": "#00FF0A",
"label": "bottle"
},
{
"index": 100,
"color": "#FF7000",
"label": "buffet;counter;sideboard"
},
{
"index": 101,
"color": "#8FFF00",
"label": "poster;posting;placard;notice;bill;card"
},
{
"index": 102,
"color": "#5200FF",
"label": "stage"
},
{
"index": 103,
"color": "#A3FF00",
"label": "van"
},
{
"index": 104,
"color": "#FFEB00",
"label": "ship"
},
{
"index": 105,
"color": "#08B8AA",
"label": "fountain"
},
{
"index": 106,
"color": "#8500FF",
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
},
{
"index": 107,
"color": "#00FF5C",
"label": "canopy"
},
{
"index": 108,
"color": "#B800FF",
"label": "washer;automatic;washer;washing;machine"
},
{
"index": 109,
"color": "#FF001F",
"label": "plaything;toy"
},
{
"index": 110,
"color": "#00B8FF",
"label": "swimming;pool;swimming;bath;natatorium"
},
{
"index": 111,
"color": "#00D6FF",
"label": "stool"
},
{
"index": 112,
"color": "#FF0070",
"label": "barrel;cask"
},
{
"index": 113,
"color": "#5CFF00",
"label": "basket;handbasket"
},
{
"index": 114,
"color": "#00E0FF",
"label": "waterfall;falls"
},
{
"index": 115,
"color": "#70E0FF",
"label": "tent;collapsible;shelter"
},
{
"index": 116,
"color": "#46B8A0",
"label": "bag"
},
{
"index": 117,
"color": "#A300FF",
"label": "minibike;motorbike"
},
{
"index": 118,
"color": "#9900FF",
"label": "cradle"
},
{
"index": 119,
"color": "#47FF00",
"label": "oven"
},
{
"index": 120,
"color": "#FF00A3",
"label": "ball"
},
{
"index": 121,
"color": "#FFCC00",
"label": "food;solid;food"
},
{
"index": 122,
"color": "#FF008F",
"label": "step;stair"
},
{
"index": 123,
"color": "#00FFEB",
"label": "tank;storage;tank"
},
{
"index": 124,
"color": "#85FF00",
"label": "trade;name;brand;name;brand;marque"
},
{
"index": 125,
"color": "#FF00EB",
"label": "microwave;microwave;oven"
},
{
"index": 126,
"color": "#F500FF",
"label": "pot;flowerpot"
},
{
"index": 127,
"color": "#FF007A",
"label": "animal;animate;being;beast;brute;creature;fauna"
},
{
"index": 128,
"color": "#FFF500",
"label": "bicycle;bike;wheel;cycle"
},
{
"index": 129,
"color": "#0ABED4",
"label": "lake"
},
{
"index": 130,
"color": "#D6FF00",
"label": "dishwasher;dish;washer;dishwashing;machine"
},
{
"index": 131,
"color": "#00CCFF",
"label": "screen;silver;screen;projection;screen"
},
{
"index": 132,
"color": "#1400FF",
"label": "blanket;cover"
},
{
"index": 133,
"color": "#FFFF00",
"label": "sculpture"
},
{
"index": 134,
"color": "#0099FF",
"label": "hood;exhaust;hood"
},
{
"index": 135,
"color": "#0029FF",
"label": "sconce"
},
{
"index": 136,
"color": "#00FFCC",
"label": "vase"
},
{
"index": 137,
"color": "#2900FF",
"label": "traffic;light;traffic;signal;stoplight"
},
{
"index": 138,
"color": "#29FF00",
"label": "tray"
},
{
"index": 139,
"color": "#AD00FF",
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
},
{
"index": 140,
"color": "#00F5FF",
"label": "fan"
},
{
"index": 141,
"color": "#4700FF",
"label": "pier;wharf;wharfage;dock"
},
{
"index": 142,
"color": "#7A00FF",
"label": "crt;screen"
},
{
"index": 143,
"color": "#00FFB8",
"label": "plate"
},
{
"index": 144,
"color": "#005CFF",
"label": "monitor;monitoring;device"
},
{
"index": 145,
"color": "#B8FF00",
"label": "bulletin;board;notice;board"
},
{
"index": 146,
"color": "#0085FF",
"label": "shower"
},
{
"index": 147,
"color": "#FFD600",
"label": "radiator"
},
{
"index": 148,
"color": "#19C2C2",
"label": "glass;drinking;glass"
},
{
"index": 149,
"color": "#66FF00",
"label": "clock"
},
{
"index": 150,
"color": "#5C00FF",
"label": "flag"
}
]

View File

@ -1,155 +0,0 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;
// Top-level command line arguments of the segformer example.
// NOTE: plain `//` comments are used on purpose here; `///` doc comments on
// clap derive items are picked up as help text and would alter `--help`.
#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
// Forwarded to `candle_examples::device` in `main` to pick the device.
#[arg(long, help = "use cpu")]
cpu: bool,
// The task to run (`segment` or `classify`), see `Commands`.
#[command(subcommand)]
command: Commands,
}
// Arguments of the `segment` subcommand.
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap
// derive items are picked up as help text and would alter `--help`.
#[derive(Args, Debug)]
struct SegmentationArgs {
    // Hub repository the weights (`model.safetensors`) and `config.json` are fetched from.
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
    )]
    model_name: String,
    // JSON file listing one `{index, color, label}` entry per class.
    #[arg(
        long,
        help = "path to the label file in json format",
        default_value = "candle-examples/examples/segformer/assets/labels.json"
    )]
    label_path: PathBuf,
    // Where the color-coded mask is written; no default, so the flag is required.
    #[arg(long, help = "path for the output mask image")]
    output_path: PathBuf,
    // Positional argument: the image to segment.
    #[arg(help = "path to image as input")]
    image: PathBuf,
}
// Arguments of the `classify` subcommand.
// NOTE: plain `//` comments are used on purpose; `///` doc comments on clap
// derive items are picked up as help text and would alter `--help`.
#[derive(Args, Debug)]
struct ClassificationArgs {
// Hub repository the weights (`model.safetensors`) and `config.json` are fetched from.
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "paolinox/segformer-finetuned-food101"
)]
model_name: String,
// Positional argument: the image to classify.
#[arg(help = "path to image as input")]
image: PathBuf,
}
// The tasks supported by this example; `main` dispatches on the variant.
// (`//` rather than `///` so that clap's generated help stays unchanged.)
#[derive(Subcommand, Debug)]
enum Commands {
// Handled by `segmentation_task`.
Segment(SegmentationArgs),
// Handled by `classification_task`.
Classify(ClassificationArgs),
}
/// Downloads `model.safetensors` and `config.json` for `model_name` through the
/// Hugging Face hub API and returns the memory-mapped weights (as `F32` on
/// `device`) together with the parsed model configuration.
///
/// Progress messages and the parsed configuration are printed to stdout.
///
/// # Errors
///
/// Fails when the hub API cannot be created, a file cannot be fetched or read,
/// the weights cannot be mapped, or the configuration is not valid JSON.
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
    println!("loading model {} via huggingface hub", model_name);
    let hub = hf_hub::api::sync::Api::new()?;
    let repo = hub.model(model_name.clone());
    let weights_path = repo.get("model.safetensors")?;
    println!("model {} downloaded and loaded", model_name);

    // SAFETY: NOTE(review) - presumably sound as long as the downloaded file is
    // not modified while it is memory-mapped; confirm against the candle docs.
    let var_builder = unsafe {
        VarBuilder::from_mmaped_safetensors(&[weights_path], candle::DType::F32, device)?
    };

    let config_path = repo.get("config.json")?;
    let config_json = std::fs::read_to_string(config_path)?;
    let config: Config = serde_json::from_str(&config_json)?;
    println!("{:?}", config);

    Ok((var_builder, config))
}
/// One entry of the label file read by `segmentation_task`.
///
/// Only the fields below are deserialized; any other key present in the JSON
/// entries (e.g. `label`) is not read by this example.
#[derive(Debug, serde::Deserialize)]
struct LabelItem {
/// Index of the class in the label file; `segmentation_task` subtracts one
/// from it to obtain the key used to look up predicted class ids.
index: u32,
/// Color of the class as a hex string; `segmentation_task` strips leading `#`
/// characters and parses the next six hex digits as red, green and blue.
color: String,
}
/// Runs semantic segmentation on `args.image` and saves a color-coded mask to
/// `args.output_path`.
///
/// The label file at `args.label_path` provides one color per class; the number
/// of entries in that file is also used as the number of labels of the model.
/// The per-pixel arg-max class is mapped to its color and the resulting mask is
/// upscaled by a factor of four before being written.
///
/// # Errors
///
/// Besides I/O, download and model errors, this fails (instead of panicking)
/// when the label file contains an index of `0`, when a color is not made of
/// at least six hex digits, or when the model predicts a class that has no
/// color in the label file.
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
    let label_file = std::fs::read_to_string(&args.label_path)?;
    let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;

    // The label file is user supplied, so malformed entries are reported as
    // errors rather than aborting the process with a panic.
    let label_colors = label_items
        .iter()
        .map(|item| -> anyhow::Result<(u32, Rgb<u8>)> {
            // Indices in the label file start at 1, predicted class ids at 0.
            let class_id = item
                .index
                .checked_sub(1)
                .ok_or_else(|| anyhow::anyhow!("label indices must start at 1, got 0"))?;
            let hex = item.color.trim_start_matches('#');
            // Parses the two hex digits found at `range` into one color channel.
            let channel = |range: std::ops::Range<usize>| -> anyhow::Result<u8> {
                let digits = hex.get(range).ok_or_else(|| {
                    anyhow::anyhow!(
                        "invalid color {:?} for label index {}",
                        item.color,
                        item.index
                    )
                })?;
                Ok(u8::from_str_radix(digits, 16)?)
            };
            let rgb = Rgb([channel(0..2)?, channel(2..4)?, channel(4..6)?]);
            Ok((class_id, rgb))
        })
        .collect::<anyhow::Result<HashMap<u32, Rgb<u8>>>>()?;

    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
    let num_labels = label_items.len();

    let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
    let segmentations = model.forward(&image)?;

    // generate a mask image
    let mask = segmentations.squeeze(0)?.argmax(0)?;
    let (h, w) = mask.dims2()?;
    let class_ids = mask.flatten_all()?.to_vec1::<u32>()?;
    let mut pixels: Vec<u8> = Vec::with_capacity(class_ids.len() * 3);
    for class_id in &class_ids {
        let color = label_colors
            .get(class_id)
            .ok_or_else(|| anyhow::anyhow!("no color defined for class id {class_id}"))?;
        pixels.extend(color.data());
    }

    let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        image::ImageBuffer::from_raw(w as u32, h as u32, pixels)
            .ok_or_else(|| anyhow::anyhow!("pixel buffer does not fit a {w}x{h} mask"))?;
    // resize
    let mask = image::DynamicImage::from(mask);
    let mask = mask.resize_to_fill(
        w as u32 * 4,
        h as u32 * 4,
        image::imageops::FilterType::CatmullRom,
    );
    mask.save(args.output_path.clone())?;
    println!("mask image saved to {:?}", args.output_path);
    Ok(())
}
/// Classifies `args.image` with the model named by `args.model_name`.
///
/// Prints the soft-maxed model output followed by the label that
/// `config.id2label` associates with the highest scoring class.
///
/// # Panics
///
/// Panics if `config.id2label` has no entry for the predicted class id.
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
    let input = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;

    // The number of output classes is fixed by this example.
    let num_labels = 7;
    let model = ImageClassificationModel::new(&config, num_labels, vb)?;

    let raw_output = model.forward(&input)?;
    let scores = candle_nn::ops::softmax_last_dim(&raw_output)?.squeeze(0)?;
    println!("classification logits {:?}", scores.to_vec1::<f32>()?);

    // `id2label` is indexed with the decimal string form of the class id.
    let best_class = scores.argmax(0)?.to_scalar::<u32>()?.to_string();
    println!("label: {}", config.id2label[&best_class]);
    Ok(())
}
pub fn main() -> anyhow::Result<()> {
let args = CliArgs::parse();
let device = candle_examples::device(args.cpu)?;
if let Commands::Segment(args) = args.command {
segmentation_task(args, &device)?
} else if let Commands::Classify(args) = args.command {
classification_task(args, &device)?
}
Ok(())
}

View File

@ -57,7 +57,7 @@ The downside is some long compilation time. You can set the
`/home/user/.candle` to ensures that the compilation artifacts are properly `/home/user/.candle` to ensures that the compilation artifacts are properly
cached. cached.
Enabling flash-attention requires both a feature flag, `--features flash-attn` Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`. and using the command line flag `--use-flash-attn`.
Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs

View File

@ -1,253 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::starcoder2::Model;
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
// Command line arguments of the starcoder2 example.
// NOTE: fields that had no `///` doc comment are documented with plain `//`
// comments so that the help text generated by clap stays unchanged.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
// NOTE(review): parsed but not read anywhere in this file's `main`.
#[arg(long)]
use_flash_attn: bool,
// The text the generation starts from; required.
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
// Hub model id; `main` falls back to "bigcode/starcoder2-3b" when absent.
#[arg(long)]
model_id: Option<String>,
// Hub revision the files are fetched from.
#[arg(long, default_value = "main")]
revision: String,
// Local config file; when absent `config.json` is fetched from the hub repo.
#[arg(long)]
config_file: Option<String>,
// Local tokenizer file; when absent `tokenizer.json` is fetched from the hub repo.
#[arg(long)]
tokenizer_file: Option<String>,
// Comma separated list of local weight files; when absent
// `model.safetensors` is fetched from the hub repo.
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
/// Entry point of the starcoder2 example.
///
/// Parses the command line, optionally enables chrome tracing, resolves the
/// config, tokenizer and weight files (local paths take precedence over the
/// hub), loads the model and runs text generation on the prompt.
fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    // The guard is kept in a binding for the whole run of `main`.
    let _guard = args.tracing.then(|| {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        guard
    });
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = args
        .model_id
        .unwrap_or_else(|| "bigcode/starcoder2-3b".to_string());
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let config_file = if let Some(file) = args.config_file {
        std::path::PathBuf::from(file)
    } else {
        repo.get("config.json")?
    };
    let tokenizer_file = if let Some(file) = args.tokenizer_file {
        std::path::PathBuf::from(file)
    } else {
        repo.get("tokenizer.json")?
    };
    let filenames: Vec<std::path::PathBuf> = if let Some(files) = args.weight_files {
        files.split(',').map(std::path::PathBuf::from).collect()
    } else {
        vec![repo.get("model.safetensors")?]
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    // BF16 is only selected on CUDA devices, everything else uses F32.
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    // SAFETY: NOTE(review) - presumably sound as long as the weight files are
    // not modified while they are memory-mapped; confirm against the candle docs.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}

View File

@ -1,29 +0,0 @@
use candle::{Result, Tensor};
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
/// Scales `wav` so that its gated BS.1770 loudness reaches a target of `-14`.
///
/// The input is returned unchanged when its root-mean-square value is below
/// `2e-3` or when no gated loudness can be computed. When
/// `loudness_compressor` is set, `tanh` is applied to the scaled signal.
///
/// `wav` is read with `to_vec1`, so it is expected to be a one-dimensional
/// `f32` tensor.
pub fn normalize_loudness(
    wav: &Tensor,
    sample_rate: u32,
    loudness_compressor: bool,
) -> Result<Tensor> {
    let rms = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
    if rms < 2e-3 {
        return Ok(wav.clone());
    }

    let samples = wav.to_vec1::<f32>()?;
    let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
    meter.push(samples.into_iter());
    let loudness =
        if let Some(gated_power) = crate::bs1770::gated_mean(meter.as_100ms_windows()) {
            gated_power.loudness_lkfs() as f64
        } else {
            return Ok(wav.clone());
        };

    // Amplitude gain that moves the measured loudness onto the target.
    let gain = 10f64.powf((-14. - loudness) / 20.);
    let scaled = (wav * gain)?;
    if !loudness_compressor {
        return Ok(scaled);
    }
    scaled.tanh()
}

View File

@ -1,506 +0,0 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! # [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//! meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//! channel_power[0].as_ref(),
//! channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//! stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```
use std::f32;
/// A 2nd-degree infinite impulse response filter: its coefficients plus the
/// state it needs to process a stream one sample at a time.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
    a1: f32,
    a2: f32,
    b0: f32,
    b1: f32,
    b2: f32,

    // The past two input and output samples.
    x1: f32,
    x2: f32,
    y1: f32,
    y2: f32,
}

impl Filter {
    /// Stage 1 of the BS.1770-4 pre-filter.
    pub fn high_shelf(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let gain_db = 3.999_843_8;
        let q = 0.707_175_25;
        let center_hz = 1_681.974_5;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        let k_sq = k * k;
        let vh = 10.0_f32.powf(gain_db / 20.0);
        let vb = vh.powf(0.499_666_78);
        // Every coefficient is normalized by a0 so that a0 itself becomes 1.
        let a0 = 1.0 + k / q + k_sq;
        Filter {
            a1: 2.0 * (k_sq - 1.0) / a0,
            a2: (1.0 - k / q + k_sq) / a0,
            b0: (vh + vb * k / q + k_sq) / a0,
            b1: 2.0 * (k_sq - vh) / a0,
            b2: (vh - vb * k / q + k_sq) / a0,
            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Stage 2 of the BS.1770-4 pre-filter.
    pub fn high_pass(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let q = 0.500_327_05;
        let center_hz = 38.135_47;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        let k_sq = k * k;
        let a0 = 1.0 + k / q + k_sq;
        Filter {
            a1: 2.0 * (k_sq - 1.0) / a0,
            a2: (1.0 - k / q + k_sq) / a0,
            b0: 1.0,
            b1: -2.0,
            b2: 1.0,
            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Feed the next input sample, get the next output sample.
    #[inline(always)]
    pub fn apply(&mut self, x0: f32) -> f32 {
        // The order of the additions is kept as is (including the leading
        // `0.0 +`) so that results stay bit-for-bit the same.
        let feed_forward = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2;
        let y0 = feed_forward - self.a1 * self.y1 - self.a2 * self.y2;

        // Shift the history: the newest samples push the older ones back.
        self.x2 = std::mem::replace(&mut self.x1, x0);
        self.y2 = std::mem::replace(&mut self.y1, y0);

        y0
    }
}
/// Compensated sum: alongside the running total it keeps the part of the
/// additions that did not fit into the total, so that many values of
/// different orders of magnitude can be summed accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
    sum: f32,
    residue: f32,
}

impl Sum {
    /// An empty sum: both the total and the residue are zero.
    #[inline(always)]
    fn zero() -> Sum {
        let (sum, residue) = (0.0, 0.0);
        Sum { sum, residue }
    }

    /// Adds `x`, carrying over whatever the new total could not represent.
    #[inline(always)]
    fn add(&mut self, x: f32) {
        let pending = self.residue + x;
        let total = self.sum + pending;
        // What was actually added is `total - self.sum`; the rest is kept.
        self.residue = pending - (total - self.sum);
        self.sum = total;
    }
}
/// The mean of the squares of the K-weighted samples in a window of time.
///
/// K-weighted power and K-weighted loudness carry the same information on
/// different scales: power is quadratic in the sample amplitudes, loudness
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between
/// power and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
///
/// BS.1770-4 uses the term “LKFS” (Loudness Units, K-Weighted, relative to
/// nominal Full Scale) to emphasize K-weighting; it is otherwise
/// interchangeable with the more widespread “LUFS” (Loudness Units, relative
/// to Full Scale). Loudness units relate to decibels as follows: boosting a
/// signal with a loudness of -<var>L<sub>K</sub></var> LUFS by
/// <var>L<sub>K</sub></var> dB (that is, multiplying the amplitude by
/// 10<sup><var>L<sub>K</sub></var>/20</sup>) brings the loudness to 0 LUFS.
///
/// K-weighting refers to a high-shelf and a high-pass filter that model the
/// effect that humans perceive a given amount of power in low frequencies as
/// less loud than the same amount of power in higher frequencies. In this
/// library the `Power` type always refers to power after K-weighting.
///
/// The nominal “full scale” is the range [-1.0, 1.0]. Since power is the mean
/// square of the samples, it lies in [0.0, 1.0] as long as no input sample
/// exceeds full scale. The power delivered by multiple channels is a weighted
/// sum over the channel powers that is not normalized, so it can exceed this
/// range.
#[derive(Copy, Clone, PartialEq, PartialOrd)]
pub struct Power(pub f32);

impl Power {
    /// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
    ///
    /// This is the inverse of `loudness_lkfs`.
    pub fn from_lkfs(lkfs: f32) -> Power {
        // Equation 2 (p.5) of BS.1770-4, solved for the power.
        let exponent = (lkfs + 0.691) * 0.1;
        Power(10.0_f32.powf(exponent))
    }

    /// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
    ///
    /// This is the inverse of `from_lkfs`.
    pub fn loudness_lkfs(&self) -> f32 {
        let Power(mean_square) = *self;
        // Equation 2 (p.5) of BS.1770-4.
        -0.691 + 10.0 * mean_square.log10()
    }
}
/// A `T` value for non-overlapping windows of audio, 100ms in length.
///
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
/// for non-overlapping windows of 100ms duration.
///
/// Such windows can later be merged into overlapping windows of 400ms, spaced
/// 100ms apart, to compute instantaneous loudness or to perform a gated
/// measurement; they can also be merged into even larger windows for a
/// momentary loudness measurement.
#[derive(Copy, Clone, Debug)]
pub struct Windows100ms<T> {
    pub inner: T,
}

impl<T> Windows100ms<T> {
    /// Wrap a new empty vector.
    pub fn new() -> Windows100ms<Vec<T>> {
        let inner = Vec::new();
        Windows100ms { inner }
    }

    /// Apply `as_ref` to the inner value.
    pub fn as_ref(&self) -> Windows100ms<&[Power]>
    where
        T: AsRef<[Power]>,
    {
        let inner: &[Power] = self.inner.as_ref();
        Windows100ms { inner }
    }

    /// Apply `as_mut` to the inner value.
    pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
    where
        T: AsMut<[Power]>,
    {
        let inner: &mut [Power] = self.inner.as_mut();
        Windows100ms { inner }
    }

    // No `is_empty` counterpart is provided, hence the lint is silenced.
    #[allow(clippy::len_without_is_empty)]
    /// Apply `len` to the inner value.
    pub fn len(&self) -> usize
    where
        T: AsRef<[Power]>,
    {
        let powers: &[Power] = self.inner.as_ref();
        powers.len()
    }
}
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
/// .unwrap_or(bs1770::Power(0.0))
/// .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
/// The number of samples that fit in 100ms of audio.
///
/// Set by `new` to `sample_rate_hz / 10` (integer division).
samples_per_100ms: u32,
/// Stage 1 filter (head effects, high shelf).
///
/// Built by `new` with `Filter::high_shelf` for the meter's sample rate.
filter_stage1: Filter,
/// Stage 2 filter (high-pass).
///
/// Built by `new` with `Filter::high_pass` for the meter's sample rate.
filter_stage2: Filter,
/// Sum of the squares over non-overlapping windows of 100ms.
///
/// Starts out empty in `new`.
windows: Windows100ms<Vec<Power>>,
/// The number of samples in the current unfinished window.
///
/// Starts at zero in `new`.
count: u32,
/// The sum of the squares of the samples in the current unfinished window.
///
/// Kept as a compensated `Sum`, starting at `Sum::zero()` in `new`.
square_sum: Sum,
}
impl ChannelLoudnessMeter {
    /// Create a meter for one channel of audio sampled at `sample_rate_hz`.
    ///
    /// Both filter stages are built for that sample rate. A window is complete
    /// after `sample_rate_hz / 10` samples (integer division), so a rate below
    /// 10 Hz yields a window length of zero and no window is ever completed.
    pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
        let rate = sample_rate_hz as f32;
        Self {
            samples_per_100ms: sample_rate_hz / 10,
            filter_stage1: Filter::high_shelf(rate),
            filter_stage2: Filter::high_pass(rate),
            windows: Windows100ms::new(),
            count: 0,
            square_sum: Sum::zero(),
        }
    }

    /// Feed input samples for loudness analysis.
    ///
    /// Every sample passes through the stage 1 and then the stage 2 filter;
    /// the square of the result is accumulated. Whenever a full 100ms worth of
    /// samples has been accumulated, the mean square of that window is
    /// appended to the list of finished windows.
    ///
    /// # Full scale
    ///
    /// Full scale for the input samples is the interval [-1.0, 1.0]. Signed
    /// integer samples can be converted as follows:
    ///
    /// ```ignore
    /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
    /// # let bits_per_sample = 16_usize;
    /// # let samples = &[0_i16];
    /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
    /// // one bit is the sign bit.
    /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
    /// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
    /// ```
    ///
    /// # Repeated calls
    ///
    /// `push` may be called any number of times; the effect is the same as
    /// feeding one chained iterator. Samples that did not yet fill a complete
    /// 100ms window are kept for the next call:
    ///
    /// ```ignore
    /// # use std::iter;
    /// # use bs1770::ChannelLoudnessMeter;
    /// let sample_rate_hz = 44_100;
    /// let samples_per_100ms = sample_rate_hz / 10;
    /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    ///
    /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
    /// assert_eq!(meter.as_100ms_windows().len(), 0);
    ///
    /// meter.push(iter::once(0.0));
    /// assert_eq!(meter.as_100ms_windows().len(), 1);
    /// ```
    pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
        let window_len = self.samples_per_100ms;
        let inv_window_len = 1.0 / window_len as f32;

        for sample in samples {
            let shelved = self.filter_stage1.apply(sample);
            let weighted = self.filter_stage2.apply(shelved);
            self.square_sum.add(weighted * weighted);
            self.count += 1;

            // The common case: the window is still filling up.
            if self.count != window_len {
                continue;
            }

            // The window is complete: store its mean square.
            self.windows
                .inner
                .push(Power(self.square_sum.sum * inv_window_len));

            // Only the running sum and the sample counter start over. Any
            // other state kept inside `square_sum` is deliberately left
            // alone so leftover energy carries into the next window.
            self.square_sum.sum = 0.0;
            self.count = 0;
        }
    }

    /// Borrow the 100ms windows that have been completed so far.
    pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
        self.windows.as_ref()
    }

    /// Consume the meter and hand over the 100ms windows completed so far.
    pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
        let Self { windows, .. } = self;
        windows
    }
}
/// Combine the power of two channels by adding them window by window.
///
/// BS.1770-4 defines the power of a multi-channel signal as a sum over
/// channels that is not normalized, which makes a stereo signal inherently
/// louder than a mono signal. For a mono signal played back on stereo
/// speakers you should therefore still call `reduce_stereo`, passing the same
/// signal for both channels.
///
/// # Panics
///
/// Panics when `left` and `right` do not hold the same number of windows.
pub fn reduce_stereo(
    left: Windows100ms<&[Power]>,
    right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    let inner: Vec<Power> = left
        .inner
        .iter()
        .zip(right.inner.iter())
        .map(|(l, r)| Power(l.0 + r.0))
        .collect();
    Windows100ms { inner }
}
/// Like `reduce_stereo`, but accumulates the right channel into the left one
/// instead of allocating a new collection.
///
/// # Panics
///
/// Panics when `left` and `right` do not hold the same number of windows.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    left.inner
        .iter_mut()
        .zip(right.inner.iter())
        .for_each(|(acc, r)| acc.0 += r.0);
}
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurment. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
let absolute_threshold = Power::from_lkfs(-70.0);
// Iterate over all 400ms windows.
for window in windows_100ms.inner.windows(4) {
// Note that the sum over channels has already been performed at this point.
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
if gating_block_power > absolute_threshold {
gating_blocks.push(gating_block_power);
}
}
if gating_blocks.is_empty() {
return None;
}
// Compute the loudness after applying the absolute gate, in order to
// determine the threshold for the relative gate.
let mut sum_power = Sum::zero();
for &gating_block_power in &gating_blocks {
sum_power.add(gating_block_power.0);
}
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
// Stage 2: Apply the relative gate.
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
let mut sum_power = Sum::zero();
let mut n_blocks = 0_usize;
for &gating_block_power in &gating_blocks {
if gating_block_power > relative_threshold {
sum_power.add(gating_block_power.0);
n_blocks += 1;
}
}
if n_blocks == 0 {
return None;
}
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
Some(relative_gated_power)
}

View File

@ -1,9 +1,6 @@
pub mod audio;
pub mod bs1770;
pub mod coco_classes; pub mod coco_classes;
pub mod imagenet; pub mod imagenet;
pub mod token_output_stream; pub mod token_output_stream;
pub mod wav;
use candle::utils::{cuda_is_available, metal_is_available}; use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor}; use candle::{Device, Result, Tensor};

View File

@ -40,7 +40,7 @@ impl TokenOutputStream {
}; };
self.tokens.push(token); self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?; let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
let text = text.split_at(prev_text.len()); let text = text.split_at(prev_text.len());
self.prev_index = self.current_index; self.prev_index = self.current_index;
self.current_index = self.tokens.len(); self.current_index = self.tokens.len();

View File

@ -1,56 +0,0 @@
use std::io::prelude::*;
/// An audio sample that can be converted to signed 16-bit PCM.
pub trait Sample {
    /// Returns this sample as a signed 16-bit PCM value.
    fn to_i16(&self) -> i16;
}

/// Floating point samples use [-1.0, 1.0] as full scale: values outside that
/// interval are clamped, and the result is scaled by `i16::MAX` and truncated
/// towards zero.
impl Sample for f32 {
    fn to_i16(&self) -> i16 {
        let scaled = self.clamp(-1.0, 1.0) * f32::from(i16::MAX);
        scaled as i16
    }
}

/// Same conversion as for `f32`, carried out in double precision.
impl Sample for f64 {
    fn to_i16(&self) -> i16 {
        let scaled = self.clamp(-1.0, 1.0) * f64::from(i16::MAX);
        scaled as i16
    }
}

/// 16-bit samples are already in the target format.
impl Sample for i16 {
    fn to_i16(&self) -> i16 {
        *self
    }
}
/// Write `samples` to `w` as a mono, 16-bit PCM WAV (RIFF/WAVE) stream.
///
/// The output is a 44-byte header (RIFF descriptor, 16-byte `fmt ` chunk and
/// `data` chunk header) followed by every sample converted with
/// [`Sample::to_i16`] and stored little-endian.
///
/// Each sample is written with its own `write_all` call, so pass a buffered
/// writer (e.g. `std::io::BufWriter`) when `w` is a file or a socket.
///
/// # Errors
///
/// Returns an error of kind [`std::io::ErrorKind::InvalidInput`], before
/// anything is written, when the sample data or the byte rate does not fit in
/// the 32-bit fields of the WAV header. Any error reported by `w` is
/// propagated unchanged.
pub fn write_pcm_as_wav<W: Write, S: Sample>(
    w: &mut W,
    samples: &[S],
    sample_rate: u32,
) -> std::io::Result<()> {
    const N_CHANNELS: u16 = 1;
    const BYTES_PER_SAMPLE: u16 = 2;
    const BITS_PER_SAMPLE: u16 = 16;
    const FMT_BODY_LEN: u32 = 16;
    // Bytes covered by the RIFF size field in addition to the sample data:
    // "WAVE" (4) + fmt chunk (8 + 16) + data chunk header (8).
    const RIFF_OVERHEAD: u32 = 4 + 8 + FMT_BODY_LEN + 8;

    let invalid = |msg: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidInput, msg);

    // The header stores sizes as u32; refuse input that would silently
    // truncate or overflow instead of emitting a corrupt header.
    let max_samples = u64::from((u32::MAX - RIFF_OVERHEAD) / u32::from(BYTES_PER_SAMPLE));
    if samples.len() as u64 > max_samples {
        return Err(invalid("too many samples for a WAV file"));
    }
    // Cannot truncate or overflow: bounded by the check above.
    let data_len = samples.len() as u32 * u32::from(BYTES_PER_SAMPLE);
    let block_align = BYTES_PER_SAMPLE * N_CHANNELS;
    let bytes_per_second = sample_rate
        .checked_mul(u32::from(block_align))
        .ok_or_else(|| invalid("sample rate too large for a WAV file"))?;

    // RIFF descriptor.
    w.write_all(b"RIFF")?;
    w.write_all(&(RIFF_OVERHEAD + data_len).to_le_bytes())?; // total length minus 8 bytes
    w.write_all(b"WAVE")?;

    // Format block.
    w.write_all(b"fmt ")?;
    w.write_all(&FMT_BODY_LEN.to_le_bytes())?; // block len minus 8 bytes
    w.write_all(&1u16.to_le_bytes())?; // PCM
    w.write_all(&N_CHANNELS.to_le_bytes())?; // one channel
    w.write_all(&sample_rate.to_le_bytes())?;
    w.write_all(&bytes_per_second.to_le_bytes())?;
    w.write_all(&block_align.to_le_bytes())?; // 2 bytes of data per sample frame
    w.write_all(&BITS_PER_SAMPLE.to_le_bytes())?; // bits per sample

    // Data block.
    w.write_all(b"data")?;
    w.write_all(&data_len.to_le_bytes())?;
    for sample in samples {
        w.write_all(&sample.to_i16().to_le_bytes())?;
    }
    Ok(())
}

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-flash-attn" name = "candle-flash-attn"
version = "0.4.1" version = "0.4.0"
edition = "2021" edition = "2021"
description = "Flash attention layer for the candle ML framework." description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
[dependencies] [dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" } candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
half = { version = "2.3.1", features = ["num-traits"] } half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies] [build-dependencies]

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-kernels" name = "candle-kernels"
version = "0.4.1" version = "0.4.0"
edition = "2021" edition = "2021"
description = "CUDA kernels for Candle" description = "CUDA kernels for Candle"

View File

@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.4.1" version = "0.4.0"
edition = "2021" edition = "2021"
description = "Metal kernels for Candle" description = "Metal kernels for Candle"

View File

@ -5,7 +5,6 @@ use serde::Deserialize;
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum Activation { pub enum Activation {
#[default] #[default]
#[serde(alias = "gelu")]
Gelu, Gelu,
#[serde(alias = "gelu_new")] #[serde(alias = "gelu_new")]
NewGelu, NewGelu,
@ -20,8 +19,6 @@ pub enum Activation {
HardSwish, HardSwish,
Elu(f64), Elu(f64),
LeakyRelu(f64), LeakyRelu(f64),
#[serde(alias = "gelu_pytorch_tanh")]
GeluPytorchTanh,
} }
impl super::Module for Activation { impl super::Module for Activation {
@ -41,7 +38,6 @@ impl super::Module for Activation {
Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?, Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha), &Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope), &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
Self::GeluPytorchTanh => xs.gelu(),
} }
} }
} }

View File

@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig {
pub output_padding: usize, pub output_padding: usize,
pub stride: usize, pub stride: usize,
pub dilation: usize, pub dilation: usize,
pub groups: usize, // TODO: support groups.
} }
impl Default for ConvTranspose1dConfig { impl Default for ConvTranspose1dConfig {
@ -86,7 +86,6 @@ impl Default for ConvTranspose1dConfig {
output_padding: 0, output_padding: 0,
stride: 1, stride: 1,
dilation: 1, dilation: 1,
groups: 1,
} }
} }
} }
@ -110,14 +109,6 @@ impl ConvTranspose1d {
pub fn config(&self) -> &ConvTranspose1dConfig { pub fn config(&self) -> &ConvTranspose1dConfig {
&self.config &self.config
} }
/// Returns a reference to the weight tensor stored in this layer.
pub fn weight(&self) -> &Tensor {
&self.weight
}
/// Returns a reference to the bias tensor, or `None` when the layer was built
/// without a bias.
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
} }
impl crate::Module for ConvTranspose1d { impl crate::Module for ConvTranspose1d {
@ -128,13 +119,12 @@ impl crate::Module for ConvTranspose1d {
self.config.output_padding, self.config.output_padding,
self.config.stride, self.config.stride,
self.config.dilation, self.config.dilation,
self.config.groups,
)?; )?;
match &self.bias { match &self.bias {
None => Ok(x), None => Ok(x),
Some(bias) => { Some(bias) => {
let b = bias.dims1()?; let b = bias.dims1()?;
let bias = bias.reshape((1, b, 1))?; let bias = bias.reshape((1, b, 1, 1))?;
Ok(x.broadcast_add(&bias)?) Ok(x.broadcast_add(&bias)?)
} }
} }
@ -268,14 +258,6 @@ impl ConvTranspose2d {
pub fn config(&self) -> &ConvTranspose2dConfig { pub fn config(&self) -> &ConvTranspose2dConfig {
&self.config &self.config
} }
/// Returns a reference to the weight tensor stored in this layer.
pub fn weight(&self) -> &Tensor {
&self.weight
}
/// Returns a reference to the bias tensor, or `None` when the layer was built
/// without a bias.
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
} }
impl crate::Module for ConvTranspose2d { impl crate::Module for ConvTranspose2d {
@ -320,22 +302,6 @@ pub fn conv1d(
Ok(Conv1d::new(ws, Some(bs), cfg)) Ok(Conv1d::new(ws, Some(bs), cfg))
} }
/// Create or initialize a 1D convolution layer without a bias.
///
/// The weight is looked up in `vb` under the name `"weight"` with shape
/// `(out_channels, in_channels / cfg.groups, kernel_size)`, using the default
/// Kaiming-normal initialization as a hint.
pub fn conv1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: crate::VarBuilder,
) -> Result<Conv1d> {
    let weight_shape = (out_channels, in_channels / cfg.groups, kernel_size);
    let ws = vb.get_with_hints(weight_shape, "weight", crate::init::DEFAULT_KAIMING_NORMAL)?;
    Ok(Conv1d::new(ws, None, cfg))
}
pub fn conv_transpose1d( pub fn conv_transpose1d(
in_channels: usize, in_channels: usize,
out_channels: usize, out_channels: usize,
@ -348,11 +314,7 @@ pub fn conv_transpose1d(
lo: -bound, lo: -bound,
up: bound, up: bound,
}; };
let ws = vb.get_with_hints( let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
(in_channels, out_channels / cfg.groups, kernel_size),
"weight",
init,
)?;
let bs = vb.get_with_hints(out_channels, "bias", init)?; let bs = vb.get_with_hints(out_channels, "bias", init)?;
Ok(ConvTranspose1d::new(ws, Some(bs), cfg)) Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
} }
@ -369,11 +331,7 @@ pub fn conv_transpose1d_no_bias(
lo: -bound, lo: -bound,
up: bound, up: bound,
}; };
let ws = vb.get_with_hints( let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
(in_channels, out_channels / cfg.groups, kernel_size),
"weight",
init,
)?;
Ok(ConvTranspose1d::new(ws, None, cfg)) Ok(ConvTranspose1d::new(ws, None, cfg))
} }

View File

@ -19,16 +19,15 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU}; pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig}; pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{ pub use conv::{
conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias, conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig, Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
}; };
pub use embedding::{embedding, Embedding}; pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT}; pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm}; pub use group_norm::{group_norm, GroupNorm};
pub use init::Init; pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_b, linear_no_bias, Linear}; pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout; pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};

View File

@ -57,34 +57,21 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer. /// Create or initialize a new linear layer.
/// ///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`. /// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> { pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt(); let bound = 1. / (in_dim as f64).sqrt();
let init_bs = crate::Init::Uniform { let init_bs = crate::Init::Uniform {
lo: -bound, lo: -bound,
up: bound, up: bound,
}; };
let bs = vb.get_with_hints(out_dim, "bias", init_bs)?; let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs))) Ok(Linear::new(ws, Some(bs)))
} }
/// Create or initialize a new linear layer without biases. /// Create or initialize a new linear layer without biases.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> { pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None)) Ok(Linear::new(ws, None))
} }
/// Create or initialize a new linear layer, with or without a bias.
///
/// Delegates to `linear` when `bias` is `true` and to `linear_no_bias`
/// otherwise.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::VarBuilder,
) -> Result<Linear> {
    match bias {
        true => linear(in_dim, out_dim, vb),
        false => linear_no_bias(in_dim, out_dim, vb),
    }
}

View File

@ -197,7 +197,7 @@ impl RNN for LSTM {
fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> { fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>(); let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
Tensor::stack(&states, 1) Tensor::cat(&states, 1)
} }
} }

View File

@ -70,7 +70,7 @@ impl VarMap {
/// ///
/// If an error is returned, some of the variables might have already been set to their new /// If an error is returned, some of the variables might have already been set to their new
/// values. /// values.
pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>( pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
&mut self, &mut self,
iter: I, iter: I,
) -> Result<()> { ) -> Result<()> {

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
use candle::test_utils::{to_vec0_round, to_vec2_round}; use candle::test_utils::{to_vec0_round, to_vec2_round};
use anyhow::Result; use anyhow::Result;
use candle::{DType, Device, Tensor, Var}; use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD}; use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
#[test] #[test]
@ -121,40 +121,3 @@ fn adamw_linear_regression() -> Result<()> {
assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873); assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
Ok(()) Ok(())
} }
#[test]
fn adamw_linear_regression_varmap() -> Result<()> {
use candle_nn::Init::Const;
// Same scenario as `adamw_linear_regression`, but the trainable parameters
// are created through a `VarMap` instead of being built by hand.
// Generator layer (weights [3, 1], bias -2) used to produce the targets.
let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
let gen = Linear::new(w_gen, Some(b_gen));
let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
let sample_ys = gen.forward(&sample_xs)?;
// Trainable parameters, registered in the map as "w" and "b" and starting at zero.
let mut var_map = candle_nn::VarMap::new();
let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
let params = ParamsAdamW {
lr: 0.1,
..Default::default()
};
// The optimizer is handed every variable registered in the map.
let mut opt = AdamW::new(var_map.all_vars(), params)?;
let lin = Linear::new(w, Some(b));
// 100 optimization steps on the summed squared error.
for _step in 0..100 {
let ys = lin.forward(&sample_xs)?;
let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
opt.backward_step(&loss)?;
}
assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);
// Overwrite both variables through the map; the assertions below expect the
// new values to be visible through `lin`.
var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;
assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
Ok(())
}

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-onnx" name = "candle-onnx"
version = "0.4.1" version = "0.4.0"
edition = "2021" edition = "2021"
description = "ONNX support for Candle" description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.1" } candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" }
candle-nn = { path = "../candle-nn", version = "0.4.1" } candle-nn = { path = "../candle-nn", version = "0.4.0" }
prost = "0.12.1" prost = "0.12.1"
[build-dependencies] [build-dependencies]

View File

@ -832,49 +832,7 @@ fn test_flatten_operation() -> Result<()> {
// #[test] // #[test]
// "Shape" // "Shape"
#[test] // #[test]
// Builds a graph with a single "Shape" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and expects the output [2, 2].
fn test_shape_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Shape".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): unlike the other operator tests, INPUT_X is listed under
// `value_info` rather than `input` here — confirm this is intentional.
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor.
let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
// The shape is read back as a 1D vector of i64.
let results = z.to_vec1::<i64>()?;
assert_eq!(results, vec![2, 2]);
Ok(())
}
// "Conv" // "Conv"
// #[test] // #[test]
@ -883,452 +841,31 @@ fn test_shape_operation() -> Result<()> {
// #[test] // #[test]
// "Abs" // "Abs"
#[test] // #[test]
// Builds a graph with a single "Abs" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_abs_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Abs".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor with mixed signs.
let x = Tensor::from_vec(
vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32],
&[2, 2],
&Device::Cpu,
)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
Ok(())
}
// "Cos" // "Cos"
#[test] // #[test]
// Builds a graph with a single "Cos" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_cos_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Cos".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor holding 0, 1, 2, 3.
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
// The comparison is exact, so the literals must match the computed f32 bit for bit.
assert_eq!(
results,
vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
);
Ok(())
}
// "Sin" // "Sin"
#[test] // #[test]
// Builds a graph with a single "Sin" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_sin_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sin".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor holding 0, 1, 2, 3.
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
// The comparison is exact, so the literals must match the computed f32 bit for bit.
assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
Ok(())
}
// "Neg" // "Neg"
#[test] // #[test]
// Builds a graph with a single "Neg" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_neg_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Neg".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor holding 1, 2, 3, 4.
let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]);
Ok(())
}
// "Erf" // "Erf"
// #[test] // #[test]
// "Tanh" // "Tanh"
#[test] // #[test]
// Builds a graph with a single "Tanh" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_tanh_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Tanh".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor holding 0, 1, 2, 3.
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
// The comparison is exact, so the literals must match the computed f32 bit for bit.
assert_eq!(
results,
vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]]
);
Ok(())
}
// "Sigmoid" // "Sigmoid"
#[test] // #[test]
// Builds a graph with a single "Sigmoid" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_sigmoid_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sigmoid".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor holding 0, 1, 2, 3.
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
// The comparison is exact, so the literals must match the computed f32 bit for bit.
assert_eq!(
results,
vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]]
);
Ok(())
}
// "Gelu" // "Gelu"
#[test] // #[test]
// Builds a graph with a single "Gelu" node reading INPUT_X and writing
// OUTPUT_Z, evaluates it on a 2x2 tensor and checks the element-wise result.
fn test_gelu_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Gelu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
// NOTE(review): INPUT_Y is declared as a graph input but the node does not
// read it and no tensor is supplied for it below — presumably a copy-paste
// leftover; confirm it can be dropped.
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
// 2x2 f32 input tensor holding 0, 1, 2, 3.
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
// Only INPUT_X is handed to the evaluator.
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Exactly one output is expected.
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
// The comparison is exact, so the literals must match the computed f32 bit for bit.
assert_eq!(
results,
vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]]
);
Ok(())
}
// "Relu" // "Relu"
#[test] // #[test]
fn test_relu_operation() -> Result<()> {
    // Single-node graph computing z = Relu(x).
    let relu_node = NodeProto {
        op_type: "Relu".to_string(),
        domain: "".to_string(),
        attribute: vec![],
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: "".to_string(),
        doc_string: "".to_string(),
    };
    let graph_input = ValueInfoProto {
        name: INPUT_X.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    };
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![relu_node],
        name: "".to_string(),
        initializer: vec![],
        input: vec![graph_input],
        output: vec![graph_output],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    // Mixed-sign input so that the clamping of negatives is observable.
    let data = vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32];
    let x = Tensor::from_vec(data, &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let results = eval
        .get(OUTPUT_Z)
        .expect("Output 'z' not found")
        .to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]);
    Ok(())
}
// "Constant" // "Constant"
// #[test] // #[test]

View File

@ -15,7 +15,6 @@ byteorder = { workspace = true }
candle = { workspace = true } candle = { workspace = true }
candle-flash-attn = { workspace = true, optional = true } candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true } candle-nn = { workspace = true }
fancy-regex = { workspace = true }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true } num-traits = { workspace = true }
rand = { workspace = true } rand = { workspace = true }

View File

@ -1,5 +1,15 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
// Builds a `Linear` layer from the variables under `vb`: a "weight" tensor of
// shape (size2, size1) and, only when `bias` is true, a "bias" tensor of
// length size2.
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = bias.then(|| vb.get(size2, "bias")).transpose()?;
    Ok(Linear::new(weight, bias))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?; let weight = vb.get(size, "weight")?;

View File

@ -1,4 +1,4 @@
use crate::models::with_tracing::{linear_b as linear, Linear}; use crate::models::with_tracing::Linear;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
@ -51,6 +51,14 @@ impl Config {
} }
} }
// Dispatches to the traced linear constructor with or without a bias term,
// depending on the `bias` flag.
fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    match bias {
        true => crate::models::with_tracing::linear(in_dim, out_dim, vb),
        false => crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb),
    }
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct RotaryEmbedding { struct RotaryEmbedding {
cache: Tensor, cache: Tensor,

View File

@ -1,460 +0,0 @@
//! EfficientViT (MSRA) inference implementation based on timm.
//!
//! See "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"
//! https://arxiv.org/abs/2305.07027
//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py
use candle::{Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
VarBuilder,
};
/// Hyper-parameters describing one EfficientViT (MSRA) variant.
#[derive(Clone, Debug)]
pub struct Config {
    /// Channel count used by each of the three stages.
    channels: [usize; 3],
    /// Number of blocks in each of the three stages.
    blocks: [usize; 3],
    /// Number of attention heads in each of the three stages.
    heads: [usize; 3],
    /// Depthwise kernel sizes, indexed by head inside the cascaded group attention.
    kernels: [usize; 4],
}

impl Config {
    /// Preset for the M0 variant.
    pub fn m0() -> Self {
        Self {
            channels: [64, 128, 192],
            blocks: [1, 2, 3],
            heads: [4, 4, 4],
            kernels: [5, 5, 5, 5],
        }
    }

    /// Preset for the M1 variant.
    pub fn m1() -> Self {
        Self {
            channels: [128, 144, 192],
            blocks: [1, 2, 3],
            heads: [2, 3, 3],
            kernels: [7, 5, 3, 3],
        }
    }

    /// Preset for the M2 variant.
    pub fn m2() -> Self {
        Self {
            channels: [128, 192, 224],
            blocks: [1, 2, 3],
            heads: [4, 3, 2],
            kernels: [7, 5, 3, 3],
        }
    }

    /// Preset for the M3 variant.
    pub fn m3() -> Self {
        Self {
            channels: [128, 240, 320],
            blocks: [1, 2, 3],
            heads: [4, 3, 4],
            kernels: [5, 5, 5, 5],
        }
    }

    /// Preset for the M4 variant.
    pub fn m4() -> Self {
        Self {
            channels: [128, 256, 384],
            blocks: [1, 2, 3],
            heads: [4, 4, 4],
            kernels: [7, 5, 3, 3],
        }
    }

    /// Preset for the M5 variant.
    pub fn m5() -> Self {
        Self {
            channels: [192, 288, 384],
            blocks: [1, 3, 4],
            heads: [3, 3, 4],
            kernels: [7, 5, 3, 3],
        }
    }
}
// One stem unit: a bias-free 3x3 convolution with stride 2 and padding 1,
// followed by batch norm applied in inference mode. Weights are read from
// the "bn" and "conv" sub-paths of `vb`.
fn efficientvit_stemblock(
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cfg = Conv2dConfig {
        stride: 2,
        padding: 1,
        ..Default::default()
    };
    let norm = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
    let conv = conv2d_no_bias(in_channels, out_channels, 3, cfg, vb.pp("conv"))?;
    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&norm, false)))
}
// Stem: four stem blocks growing the channel count 3 -> dim/8 -> dim/4 ->
// dim/2 -> dim, with a ReLU after each of the first three blocks.
fn efficientvit_stem(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let conv1 = efficientvit_stemblock(3, dim / 8, vb.pp("conv1"))?;
    let conv2 = efficientvit_stemblock(dim / 8, dim / 4, vb.pp("conv2"))?;
    let conv3 = efficientvit_stemblock(dim / 4, dim / 2, vb.pp("conv3"))?;
    let conv4 = efficientvit_stemblock(dim / 2, dim, vb.pp("conv4"))?;
    Ok(Func::new(move |xs| {
        let hidden = xs.apply(&conv1)?.relu()?;
        let hidden = hidden.apply(&conv2)?.relu()?;
        let hidden = hidden.apply(&conv3)?.relu()?;
        // No activation after the last block.
        hidden.apply(&conv4)
    }))
}
// Depthwise convolution (groups == channels, no bias) followed by batch norm
// applied in inference mode.
fn depthwise_conv(
    channels: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cfg = Conv2dConfig {
        stride,
        padding,
        groups: channels,
        ..Default::default()
    };
    let norm = batch_norm(channels, 1e-5, vb.pp("bn"))?;
    let conv = conv2d_no_bias(channels, channels, kernel, cfg, vb.pp("conv"))?;
    Ok(Func::new(move |xs| {
        let ys = xs.apply(&conv)?;
        ys.apply_t(&norm, false)
    }))
}
fn pointwise_conv(
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;
Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
}
// Two-layer convolutional MLP: expand to `out_channels`, ReLU, then project
// back to `in_channels`.
fn conv_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let expand = pointwise_conv(in_channels, out_channels, vb.pp("pw1"))?;
    let project = pointwise_conv(out_channels, in_channels, vb.pp("pw2"))?;
    Ok(Func::new(move |xs| {
        let hidden = xs.apply(&expand)?.relu()?;
        hidden.apply(&project)
    }))
}
// Fixed per-stage resolutions, indexed by stage. The attention block compares
// the entry for its stage against the window resolution to decide whether the
// feature map has to be split into windows.
// NOTE(review): these values presumably assume a 224x224 input image - confirm.
const RESOLUTIONS: [usize; 3] = [14, 7, 4];
// Attention block.
//
// Wraps the cascaded group attention with local windowing: when the stage
// resolution is larger than the 7x7 window, the feature map is zero-padded to
// a multiple of the window size, split into non-overlapping windows that are
// processed independently, merged back and cropped to the original size.
//
// The input is expected to be a rank-4 tensor, read as (b, c, h, w).
fn efficientvit_attn(
    cfg: &Config,
    stage: usize,
    in_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cga = cascaded_group_attn(cfg, stage, in_channels, vb)?;
    Ok(Func::new(move |xs| {
        let (b, c, h, w) = xs.dims4()?;
        let win_res = 7; // Fixed window resolution
        if RESOLUTIONS[stage] <= win_res {
            // The whole feature map fits in a single window.
            return xs.apply(&cga);
        }

        // Amount of padding needed to reach a multiple of the window size.
        let pad_b = (win_res - h % win_res) % win_res;
        let pad_r = (win_res - w % win_res) % win_res;
        let ph = h + pad_b;
        let pw = w + pad_r;
        let nh = ph / win_res;
        let nw = pw / win_res;

        // (b, c, h, w) -> (b, h, w, c). In this layout the width is axis 2 and
        // the height is axis 1, so these are the axes that must be padded (the
        // last axis holds the channels and must stay untouched).
        let xs = xs
            .permute((0, 2, 3, 1))?
            .pad_with_zeros(2, 0, pad_r)?
            .pad_with_zeros(1, 0, pad_b)?;

        // Split into windows: (b * nh * nw, c, win_res, win_res).
        let xs = xs
            .reshape((b, nh, win_res, nw, win_res, c))?
            .transpose(2, 3)?
            .reshape((b * nh * nw, win_res, win_res, c))?
            .permute((0, 3, 1, 2))?;

        let xs = xs.apply(&cga)?;

        // Merge the windows back into a (b, ph, pw, c) map.
        let xs = xs
            .permute((0, 2, 3, 1))?
            .reshape((b, nh, nw, win_res, win_res, c))?
            .transpose(2, 3)?
            .reshape((b, ph, pw, c))?;

        // Drop the padded rows/columns so that the output has the same shape
        // as the input, then restore the (b, c, h, w) layout.
        xs.narrow(1, 0, h)?.narrow(2, 0, w)?.permute((0, 3, 1, 2))
    }))
}
// Cascaded group attention
fn cascaded_group_attn(
cfg: &Config,
stage: usize,
in_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let heads = cfg.heads[stage];
let key_dim = 16;
let val_dim = in_channels / heads;
let scale = (key_dim as f64).powf(-0.5);
let mut dws = Vec::with_capacity(heads);
let mut qkvs = Vec::with_capacity(heads);
for i in 0..heads {
dws.push(depthwise_conv(
key_dim,
cfg.kernels[i],
1,
cfg.kernels[i] / 2,
vb.pp(format!("dws.{i}")),
)?);
qkvs.push(pointwise_conv(
in_channels / heads,
in_channels / heads + 2 * key_dim,
vb.pp(format!("qkvs.{i}")),
)?);
}
let proj = pointwise_conv(in_channels, in_channels, vb.pp("proj.1"))?;
Ok(Func::new(move |xs| {
let (b, _, h, w) = xs.dims4()?;
let feats_in = xs.chunk(heads, 1)?;
let mut feats_out = Vec::with_capacity(heads);
let mut feat = feats_in[0].clone();
for i in 0..heads {
if i > 0 {
feat = (&feat + &feats_in[i])?;
}
feat = feat.apply(&qkvs[i])?;
let res = feat.reshape((b, (), h, w))?;
let q = res.narrow(1, 0, key_dim)?;
let k = res.narrow(1, key_dim, key_dim)?;
let v = res.narrow(1, 2 * key_dim, val_dim)?;
let q = q.apply(&dws[i])?;
let q = q.flatten_from(2)?;
let k = k.flatten_from(2)?;
let v = v.flatten_from(2)?;
let q = (q * scale)?;
let att = q.transpose(D::Minus2, D::Minus1)?.matmul(&k)?;
let att = softmax(&att, D::Minus1)?;
feat = v.matmul(&att.transpose(D::Minus2, D::Minus1)?)?;
feat = feat.reshape((b, val_dim, h, w))?;
feats_out.push(feat.clone());
}
let xs = Tensor::cat(&feats_out, 1)?;
let xs = xs.relu()?.apply(&proj)?;
Ok(xs)
}))
}
// Used by the downsampling layer
fn squeeze_and_excitation(
in_channels: usize,
squeeze_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
residual.broadcast_mul(&xs)
}))
}
// Used by the downsampling layer.
// Patch merging: pointwise expansion to 4x the input channels, stride-2
// depthwise 3x3 convolution, squeeze-and-excitation, then a pointwise
// projection to `out_channels`. A ReLU follows each of the first two convs.
fn patchmerge(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let hidden = in_channels * 4;
    let expand = pointwise_conv(in_channels, hidden, vb.pp("conv1"))?;
    let reduce = depthwise_conv(hidden, 3, 2, 1, vb.pp("conv2"))?;
    let project = pointwise_conv(hidden, out_channels, vb.pp("conv3"))?;
    let se = squeeze_and_excitation(hidden, hidden / 4, vb.pp("se"))?;
    Ok(Func::new(move |xs| {
        let ys = xs.apply(&expand)?.relu()?;
        let ys = ys.apply(&reduce)?.relu()?;
        ys.apply(&se)?.apply(&project)
    }))
}
// Used by the downsampling layer.
// Two residual sub-layers applied in sequence: a 3x3 depthwise convolution,
// then a convolutional MLP with a 2x expansion.
fn res(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let dw = depthwise_conv(dim, 3, 1, 1, vb.pp("0.m"))?;
    let mlp = conv_mlp(dim, dim * 2, vb.pp("1.m"))?;
    Ok(Func::new(move |xs| {
        let after_dw = (xs + xs.apply(&dw)?)?;
        &after_dw + after_dw.apply(&mlp)?
    }))
}
// Downsampling: a residual block at the input width, patch merging to the
// output width, then a residual block at the output width.
fn efficientvit_downsample(
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let pre = res(in_channels, vb.pp("res1"))?;
    let post = res(out_channels, vb.pp("res2"))?;
    let merge = patchmerge(in_channels, out_channels, vb.pp("patchmerge"))?;
    Ok(Func::new(move |xs| {
        let ys = xs.apply(&pre)?;
        let ys = ys.apply(&merge)?;
        ys.apply(&post)
    }))
}
// One EfficientViT block: five residual sub-layers applied in the order
// depthwise conv, MLP, windowed attention, depthwise conv, MLP.
fn efficientvit_block(
    cfg: &Config,
    stage: usize,
    dim: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let dw0 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw0.m"))?;
    let dw1 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw1.m"))?;
    let ffn0 = conv_mlp(dim, dim * 2, vb.pp("ffn0.m"))?;
    let ffn1 = conv_mlp(dim, dim * 2, vb.pp("ffn1.m"))?;
    let attn = efficientvit_attn(cfg, stage, dim, vb.pp("mixer.m.attn"))?;
    Ok(Func::new(move |xs| {
        let mut ys = xs.clone();
        // Each sub-layer is wrapped in a residual connection.
        for sublayer in [&dw0, &ffn0, &attn, &dw1, &ffn1] {
            ys = (&ys + ys.apply(sublayer)?)?;
        }
        Ok(ys)
    }))
}
// Each stage is made of blocks. There is a downsampling layer between stages:
// every stage but the first starts with a downsampling layer that maps the
// channel count of the previous stage to the one of this stage.
fn efficientvit_stage(cfg: &Config, stage: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let nblocks = cfg.blocks[stage];
    let out_channels = cfg.channels[stage];
    // Stage 0 has no predecessor, so index 0 is used for it as well.
    let in_channels = cfg.channels[stage.saturating_sub(1)];

    let mut layers = Vec::with_capacity(nblocks + 1);
    if stage > 0 {
        let downsample = efficientvit_downsample(in_channels, out_channels, vb.pp("downsample"))?;
        layers.push(downsample);
    }
    for i in 0..nblocks {
        let block = efficientvit_block(cfg, stage, out_channels, vb.pp(format!("blocks.{i}")))?;
        layers.push(block);
    }

    Ok(Func::new(move |xs| {
        layers
            .iter()
            .try_fold(xs.clone(), |acc, layer| acc.apply(layer))
    }))
}
// Classification head: batch norm (inference mode) followed by a linear layer
// mapping `outputs` features to `nclasses` values.
fn efficientvit_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let norm = batch_norm(outputs, 1e-6, vb.pp("bn"))?;
    let fc = linear(outputs, nclasses, vb.pp("linear"))?;
    Ok(Func::new(move |xs| {
        let normed = xs.apply_t(&norm, false)?;
        normed.apply(&fc)
    }))
}
// Build an efficientvit model for a given configuration.
// The model is stem -> three stages -> mean over the last two dimensions,
// followed by the classification head when `nclasses` is provided.
fn efficientvit_model(
    config: &Config,
    nclasses: Option<usize>,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cls = nclasses
        .map(|nclasses| efficientvit_head(config.channels[2], nclasses, vb.pp("head")))
        .transpose()?;
    let stem = efficientvit_stem(config.channels[0], vb.pp("patch_embed"))?;
    let stages_vb = vb.pp("stages");
    let stage1 = efficientvit_stage(config, 0, stages_vb.pp(0))?;
    let stage2 = efficientvit_stage(config, 1, stages_vb.pp(1))?;
    let stage3 = efficientvit_stage(config, 2, stages_vb.pp(2))?;
    Ok(Func::new(move |xs| {
        let features = xs
            .apply(&stem)?
            .apply(&stage1)?
            .apply(&stage2)?
            .apply(&stage3)?;
        let pooled = features.mean(D::Minus2)?.mean(D::Minus1)?;
        match cls.as_ref() {
            Some(head) => pooled.apply(head),
            None => Ok(pooled),
        }
    }))
}
/// Builds the full EfficientViT model, including the classification head
/// producing `nclasses` outputs.
pub fn efficientvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let head_classes = Some(nclasses);
    efficientvit_model(cfg, head_classes, vb)
}
/// Builds the EfficientViT model without the classification head: the output
/// is the pooled feature tensor of the last stage.
pub fn efficientvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    let no_head: Option<usize> = None;
    efficientvit_model(cfg, no_head, vb)
}

View File

@ -1,773 +0,0 @@
#![allow(unused)]
use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
// Normalization scheme used when building the 1d convolutions of the model.
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum NormType {
// The convolution weight is recomputed from its weight-norm parametrization.
WeightNorm,
// A plain convolution followed by a single-group group norm.
TimeGroupNorm,
// A plain convolution without any normalization.
None,
}
// Padding strategy applied on the last dimension before a convolution.
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum PadMode {
// Pad with zeros.
Constant,
// Not supported by `pad1d`, which returns an error for this mode.
Reflect,
// Pad via `pad_with_same`.
Replicate,
}
// Encodec model configuration, deserializable with serde.
// NOTE(review): field names presumably mirror the Python transformers
// `EncodecConfig` - confirm.
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
// The last entry is used to derive the number of quantizers.
pub target_bandwidths: Vec<f64>,
// Used together with `upsampling_ratios` to derive the frame rate.
pub sampling_rate: usize,
// Input channels of the encoder and output channels of the decoder.
pub audio_channels: usize,
pub normalize: bool,
pub chunk_length_s: Option<usize>,
pub overlap: Option<usize>,
// Output channels of the encoder, also the default codebook dimension.
pub hidden_size: usize,
// Base number of convolution channels, scaled by powers of two per stage.
pub num_filters: usize,
// Number of resnet blocks in each sampling stage.
pub num_residual_layers: usize,
// Stride of each sampling stage; the product is the hop length.
pub upsampling_ratios: Vec<usize>,
pub norm_type: NormType,
// Kernel size of the first encoder convolution.
pub kernel_size: usize,
// Kernel size of the last encoder convolution and of the first/last decoder ones.
pub last_kernel_size: usize,
// Kernel size of the first convolution in each resnet block.
pub residual_kernel_size: usize,
// Base of the dilation used by the j-th resnet block of a stage.
pub dilation_growth_rate: usize,
// When true, the padding is added on the left (plus the extra right padding).
pub use_causal_conv: bool,
pub pad_mode: PadMode,
// Resnet blocks use `dim / compress` hidden channels.
pub compress: usize,
pub num_lstm_layers: usize,
pub trim_right_ratio: f64,
// Number of entries in each codebook.
pub codebook_size: usize,
// Codebook embedding dimension, `hidden_size` is used when this is `None`.
pub codebook_dim: Option<usize>,
// When true, resnet blocks use a 1x1 convolution on the shortcut branch.
pub use_conv_shortcut: bool,
}
// Default configuration.
// NOTE(review): presumably matches the 24kHz mono encodec checkpoint - confirm.
impl Default for Config {
fn default() -> Self {
Self {
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
sampling_rate: 24_000,
audio_channels: 1,
normalize: false,
chunk_length_s: None,
overlap: None,
hidden_size: 128,
num_filters: 32,
num_residual_layers: 1,
upsampling_ratios: vec![8, 5, 4, 2],
norm_type: NormType::WeightNorm,
kernel_size: 7,
last_kernel_size: 7,
residual_kernel_size: 3,
dilation_growth_rate: 2,
use_causal_conv: true,
// This should be PadMode::Reflect which is currently unsupported in candle.
pad_mode: PadMode::Replicate,
compress: 2,
num_lstm_layers: 2,
trim_right_ratio: 1.0,
codebook_size: 1024,
// `None` means that `hidden_size` is used as the codebook dimension.
codebook_dim: None,
use_conv_shortcut: true,
}
}
}
impl Config {
    // Dimension of the codebook embeddings, defaulting to `hidden_size`.
    fn codebook_dim(&self) -> usize {
        match self.codebook_dim {
            Some(dim) => dim,
            None => self.hidden_size,
        }
    }

    // Number of frames per second: the sampling rate divided by the hop
    // length (product of the upsampling ratios), rounded up.
    fn frame_rate(&self) -> usize {
        let hop_length: usize = self.upsampling_ratios.iter().product();
        let rounded_up = self.sampling_rate + hop_length - 1;
        rounded_up / hop_length
    }

    // Number of quantizers derived from the largest (last) target bandwidth.
    // Panics if `target_bandwidths` is empty.
    fn num_quantizers(&self) -> usize {
        let max_bandwidth = self
            .target_bandwidths
            .last()
            .expect("empty target_bandwidths");
        let num = 1000f64 * max_bandwidth;
        let denom = self.frame_rate() * 10;
        (num as usize) / denom
    }
}
// Computes how much padding has to be added at the end of the last dimension
// of `xs` so that the last convolution window is complete. All subtractions
// saturate at zero.
fn get_extra_padding_for_conv1d(
    xs: &Tensor,
    k_size: usize,
    stride: usize,
    padding_total: usize,
) -> Result<usize> {
    let len = xs.dim(D::Minus1)?;
    let covered = (len + padding_total).saturating_sub(k_size);
    // Possibly fractional number of frames produced by the convolution.
    let n_frames = covered as f64 / stride as f64 + 1.0;
    let last_frame = n_frames.ceil() as usize - 1;
    let ideal_len = (last_frame * stride + k_size).saturating_sub(padding_total);
    Ok(ideal_len.saturating_sub(len))
}
// Pads the last dimension of `xs` with `pad_l` elements on the left and
// `pad_r` on the right according to `mode`. Reflect padding is rejected with
// an error.
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
    let padded = match mode {
        PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
        PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r)?,
        PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r)?,
    };
    Ok(padded)
}
// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
//
// The weight is rebuilt as `weight_v * weight_g / ||weight_v||` where the norm
// is taken over dimensions 1 and 2. The bias is always loaded.
pub fn conv1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    config: candle_nn::Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let gain = vb.get((out_c, 1, 1), "weight_g")?;
    let direction = vb.get((out_c, in_c, kernel_size), "weight_v")?;
    let direction_norm = direction.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    let scaled = direction.broadcast_mul(&gain)?;
    let weight = scaled.broadcast_div(&direction_norm)?;
    let bias = vb.get(out_c, "bias")?;
    Ok(Conv1d::new(weight, Some(bias), config))
}
// Same as `conv1d_weight_norm` for transposed convolutions: the weight has
// shape (in_c, out_c, kernel_size) and the gain has shape (in_c, 1, 1). The
// bias is only loaded when `bias` is true.
fn conv_transpose1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    bias: bool,
    config: candle_nn::ConvTranspose1dConfig,
    vb: VarBuilder,
) -> Result<ConvTranspose1d> {
    let gain = vb.get((in_c, 1, 1), "weight_g")?;
    let direction = vb.get((in_c, out_c, kernel_size), "weight_v")?;
    let direction_norm = direction.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    let scaled = direction.broadcast_mul(&gain)?;
    let weight = scaled.broadcast_div(&direction_norm)?;
    let bias = bias.then(|| vb.get(out_c, "bias")).transpose()?;
    Ok(ConvTranspose1d::new(weight, bias, config))
}
// Custom cpu op returning, for each row of the lhs, the index of the rhs row
// with the smallest squared euclidean distance (first one in case of ties).
struct CodebookEncode;

impl candle::CustomOp2 for CodebookEncode {
    fn name(&self) -> &'static str {
        "cb"
    }

    // Both inputs have to be contiguous rank-2 f32 tensors sharing the same,
    // non-empty, last dimension. The output is a rank-1 u32 tensor with one
    // index per lhs row.
    fn cpu_fwd(
        &self,
        lhs_storage: &candle::CpuStorage,
        lhs_layout: &Layout,
        rhs_storage: &candle::CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<(candle::CpuStorage, Shape)> {
        use rayon::prelude::*;

        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
        if lhs_dim2 != rhs_dim2 {
            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
        }
        if lhs_dim2 == 0 {
            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
        }
        let lhs = match lhs_layout.contiguous_offsets() {
            Some((start, end)) => &lhs_storage.as_slice::<f32>()?[start..end],
            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
        };
        let rhs = match rhs_layout.contiguous_offsets() {
            Some((start, end)) => &rhs_storage.as_slice::<f32>()?[start..end],
            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
        };

        // The lhs rows are processed in parallel, each one scanning all the
        // rhs rows sequentially.
        let dst: Vec<u32> = (0..lhs_dim1)
            .into_par_iter()
            .map(|row| {
                let query = &lhs[row * lhs_dim2..(row + 1) * lhs_dim2];
                let mut best_idx = 0;
                let mut best_dist = f32::INFINITY;
                for candidate_idx in 0..rhs_dim1 {
                    let candidate =
                        &rhs[candidate_idx * rhs_dim2..(candidate_idx + 1) * rhs_dim2];
                    let dist = query
                        .iter()
                        .zip(candidate.iter())
                        .fold(0f32, |acc, (a, b)| acc + (a - b) * (a - b));
                    // Strict comparison so that the first minimum is kept.
                    if dist < best_dist {
                        best_dist = dist;
                        best_idx = candidate_idx;
                    }
                }
                best_idx as u32
            })
            .collect();
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, (lhs_dim1,).into()))
    }
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
// Codebook mapping vectors to the index of their closest embedding and back.
#[derive(Clone, Debug)]
pub struct EuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    embed: candle_nn::Embedding,
    embed_avg: Tensor,
    // Half of the squared norm of each embedding, cached for `encode_slow`.
    c2: Tensor,
}

impl EuclideanCodebook {
    // Loads the "inited", "cluster_size", "embed" and "embed_avg" tensors.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dim = cfg.codebook_dim();
        let e_shape = (cfg.codebook_size, dim);
        let inited = vb.get(1, "inited")?;
        let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
        let embed = vb.get(e_shape, "embed")?;
        let sq_norms = (&embed * &embed)?.sum(D::Minus1)?;
        let c2 = (sq_norms / 2.0)?;
        let embed_avg = vb.get(e_shape, "embed_avg")?;
        let embed = candle_nn::Embedding::new(embed, dim);
        Ok(Self {
            inited,
            cluster_size,
            embed,
            embed_avg,
            c2,
        })
    }

    // Nearest-embedding lookup using tensor ops only: minimizing
    // `|e|^2 / 2 - x.e` over the embeddings e is equivalent to minimizing the
    // euclidean distance to x. The last dimension of `xs` is consumed.
    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
        let dims = xs.dims();
        let target_shape = dims[..dims.len().saturating_sub(1)].to_vec();
        let flat = xs.flatten_to(D::Minus2)?;
        let _ = flat.dims2()?;
        let dot_prod = flat.matmul(&self.embed.embeddings().t()?)?;
        let scores = self.c2.broadcast_sub(&dot_prod)?;
        scores.argmin(D::Minus1)?.reshape(target_shape)
    }

    // Nearest-embedding lookup using the `CodebookEncode` custom op.
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let dims = xs.dims();
        let target_shape = dims[..dims.len().saturating_sub(1)].to_vec();
        let flat = xs.flatten_to(D::Minus2)?;
        let _ = flat.dims2()?;
        Tensor::apply_op2(&flat, self.embed.embeddings(), CodebookEncode)?.reshape(target_shape)
    }

    // Maps indices back to their embeddings.
    pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        self.embed.forward(embed_ind)
    }
}
// A single quantization layer wrapping one euclidean codebook. Dimensions 1
// and 2 are swapped before encoding and after decoding.
#[derive(Clone, Debug)]
pub struct VectorQuantization {
    codebook: EuclideanCodebook,
}

impl VectorQuantization {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        EuclideanCodebook::new(cfg, vb.pp("codebook")).map(|codebook| Self { codebook })
    }

    // Returns the codebook indices for `xs`, using the tensor-op based lookup.
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let swapped = xs.transpose(1, 2)?;
        self.codebook.encode_slow(&swapped)
    }

    // Returns the embeddings for the given codebook indices.
    pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        self.codebook.decode(embed_ind)?.transpose(1, 2)
    }
}
// Stack of quantization layers where each layer encodes the residual left by
// the previous ones.
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
    layers: Vec<VectorQuantization>,
    dtype: DType,
}

impl ResidualVectorQuantizer {
    // Loads `cfg.num_quantizers()` layers from the "layers.{i}" sub-paths.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layers");
        let num_layers = cfg.num_quantizers();
        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(VectorQuantization::new(cfg, vb.pp(i))?);
        }
        let dtype = vb.dtype();
        Ok(Self { layers, dtype })
    }

    // Returns the codes of all the layers stacked on a new first dimension.
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let mut all_codes = Vec::with_capacity(self.layers.len());
        let mut residual = xs.clone();
        for layer in &self.layers {
            let indices = layer.encode(&residual)?;
            // What this layer could not represent is left to the next one.
            residual = (residual - layer.decode(&indices)?)?;
            all_codes.push(indices);
        }
        Tensor::stack(&all_codes, 0)
    }

    // Sums the embeddings of each layer for `codes`, indexed by layer on the
    // first dimension. Fewer codes than layers can be provided, more is an error.
    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let init = Tensor::zeros((), self.dtype, codes.device())?;
        let ncodes = codes.dim(0)?;
        if ncodes > self.layers.len() {
            candle::bail!(
                "codes shape {:?} does not match the number of quantization layers {}",
                codes.shape(),
                self.layers.len()
            )
        }
        self.layers
            .iter()
            .take(ncodes)
            .enumerate()
            .try_fold(init, |acc, (i, layer)| {
                layer.decode(&codes.i(i)?)?.broadcast_add(&acc)
            })
    }
}
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
// Stack of LSTM layers with a residual connection around the whole stack.
#[derive(Clone, Debug)]
pub struct EncodecLSTM {
    layers: Vec<candle_nn::LSTM>,
}

impl EncodecLSTM {
    // Loads `cfg.num_lstm_layers` layers from the "lstm" sub-path, each layer
    // being selected through the `layer_idx` of its config.
    pub fn new(dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("lstm");
        let layers = (0..cfg.num_lstm_layers)
            .map(|layer_idx| {
                let config = candle_nn::LSTMConfig {
                    layer_idx,
                    ..Default::default()
                };
                candle_nn::lstm(dim, dim, config, vb.clone())
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }
}

impl Module for EncodecLSTM {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;
        // This is different from the Python transformers version as candle LSTM is batch first.
        let batch_first = xs.t()?;
        let mut hidden = batch_first.clone();
        for layer in &self.layers {
            let states = layer.seq(&hidden)?;
            hidden = layer.states_to_tensor(&states)?;
        }
        // Residual connection, then back to the input layout.
        (hidden + &batch_first)?.t()
    }
}
// Weight-normalized transposed 1d convolution with a bias. The forward pass
// only applies the convolution, no trimming of the output is performed.
#[derive(Clone, Debug)]
pub struct EncodecConvTranspose1d {
    conv: ConvTranspose1d,
}

impl EncodecConvTranspose1d {
    // The weights are read from the "conv" sub-path. `_cfg` is currently unused.
    fn new(
        in_c: usize,
        out_c: usize,
        k: usize,
        stride: usize,
        _cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let conv_cfg = candle_nn::ConvTranspose1dConfig {
            stride,
            ..Default::default()
        };
        conv_transpose1d_weight_norm(in_c, out_c, k, true, conv_cfg, vb.pp("conv"))
            .map(|conv| Self { conv })
    }
}

impl Module for EncodecConvTranspose1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.conv.forward(xs)
    }
}
// 1d convolution with the padding handled explicitly in the forward pass and
// an optional group norm applied on the output.
#[derive(Clone, Debug)]
pub struct EncodecConv1d {
    // When true, all the regular padding is added on the left.
    causal: bool,
    conv: Conv1d,
    // Only set for `NormType::TimeGroupNorm`.
    norm: Option<candle_nn::GroupNorm>,
    pad_mode: PadMode,
}

impl EncodecConv1d {
    // Builds the convolution from the "conv" sub-path of `vb` (and the group
    // norm from "norm" if any). `stride` and `dilation` apply whatever the
    // normalization type is; `cfg.norm_type` only selects how the weights are
    // loaded and whether a group norm follows the convolution.
    pub fn new(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        stride: usize,
        dilation: usize,
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        // The padding is applied manually in `forward`, hence `padding: 0`.
        let conv_cfg = candle_nn::Conv1dConfig {
            padding: 0,
            stride,
            groups: 1,
            dilation,
        };
        let conv = match cfg.norm_type {
            NormType::WeightNorm => {
                conv1d_weight_norm(in_c, out_c, kernel_size, conv_cfg, vb.pp("conv"))?
            }
            NormType::None | NormType::TimeGroupNorm => {
                conv1d(in_c, out_c, kernel_size, conv_cfg, vb.pp("conv"))?
            }
        };
        let norm = match cfg.norm_type {
            NormType::None | NormType::WeightNorm => None,
            NormType::TimeGroupNorm => {
                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
                Some(gn)
            }
        };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
            norm,
            pad_mode: cfg.pad_mode,
        })
    }
}

impl Module for EncodecConv1d {
    // `xs` has to be a rank-3 tensor, the padding and the convolution are
    // applied on its last dimension.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_b, _t, _c) = xs.dims3()?;
        let k_size = self.conv.weight().dim(D::Minus1)?;
        let conv_cfg = self.conv.config();
        // Effective kernel size with dilations.
        let k_size = (k_size - 1) * conv_cfg.dilation + 1;
        let padding_total = k_size - conv_cfg.stride;
        // Additional right padding so that the last window is complete.
        let extra_padding =
            get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
        let xs = if self.causal {
            // Causal: the output at a given step does not see future inputs.
            pad1d(xs, padding_total, extra_padding, self.pad_mode)?
        } else {
            // Non-causal: the padding is split between both sides, the left
            // side getting the extra element when the total is odd.
            let padding_right = padding_total / 2;
            let padding_left = padding_total - padding_right;
            pad1d(
                xs,
                padding_left,
                padding_right + extra_padding,
                self.pad_mode,
            )?
        };
        let xs = self.conv.forward(&xs)?;
        match &self.norm {
            None => Ok(xs),
            Some(norm) => xs.apply(norm),
        }
    }
}
// Residual block: ELU -> conv -> ELU -> conv on the main branch, added to
// either the input or a 1x1 convolution of the input.
#[derive(Clone, Debug)]
pub struct EncodecResnetBlock {
    block_conv1: EncodecConv1d,
    block_conv2: EncodecConv1d,
    shortcut: Option<EncodecConv1d>,
}

impl EncodecResnetBlock {
    // The two convolutions are read from the "block.1" and "block.3"
    // sub-paths, the optional shortcut convolution from "shortcut".
    pub fn new(
        dim: usize,
        (dilation1, dilation2): (usize, usize),
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let hidden = dim / cfg.compress;
        let mut layer = Layer::new(vb.pp("block"));
        // The skipped indices presumably belong to the weight-less ELU layers.
        layer.inc();
        let kernel = cfg.residual_kernel_size;
        let block_conv1 = EncodecConv1d::new(dim, hidden, kernel, 1, dilation1, cfg, layer.next())?;
        layer.inc();
        let block_conv2 = EncodecConv1d::new(hidden, dim, 1, 1, dilation2, cfg, layer.next())?;
        let shortcut = cfg
            .use_conv_shortcut
            .then(|| EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut")))
            .transpose()?;
        Ok(Self {
            block_conv1,
            block_conv2,
            shortcut,
        })
    }
}

impl Module for EncodecResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = xs.elu(1.)?;
        let hidden = self.block_conv1.forward(&hidden)?.elu(1.)?;
        let hidden = self.block_conv2.forward(&hidden)?;
        match &self.shortcut {
            Some(shortcut) => hidden.add(&shortcut.forward(xs)?),
            None => hidden + xs,
        }
    }
}
// Helper handing out var-builders for consecutively numbered sub-paths.
struct Layer<'a> {
    vb: VarBuilder<'a>,
    // Index of the next sub-path.
    cnt: usize,
}

impl<'a> Layer<'a> {
    fn new(vb: VarBuilder<'a>) -> Self {
        Self { vb, cnt: 0 }
    }

    // Skips one index without creating a var-builder for it.
    fn inc(&mut self) {
        self.cnt += 1;
    }

    // Returns the var-builder for the current index and moves to the next one.
    fn next(&mut self) -> VarBuilder {
        let idx = self.cnt;
        self.cnt = idx + 1;
        self.vb.pp(idx.to_string())
    }
}
// Encoder: initial convolution, a sequence of downsampling stages (resnet
// blocks followed by a strided convolution doubling the channel count), an
// LSTM and a final convolution to `hidden_size` channels.
#[derive(Clone, Debug)]
pub struct Encoder {
    init_conv: EncodecConv1d,
    sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
    final_lstm: EncodecLSTM,
    final_conv: EncodecConv1d,
}

impl Encoder {
    // The sub-layers are read from consecutive "layers.{i}" sub-paths, the
    // indices of the ELU activations being skipped.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let init_conv = EncodecConv1d::new(
            cfg.audio_channels,
            cfg.num_filters,
            cfg.kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;

        let mut sampling_layers = Vec::new();
        let mut scaling = 1;
        // The ratios are used in reverse order compared to the decoder.
        for &ratio in cfg.upsampling_ratios.iter().rev() {
            let channels = scaling * cfg.num_filters;
            let mut resnets = Vec::new();
            for j in 0..(cfg.num_residual_layers as u32) {
                let dilations = (cfg.dilation_growth_rate.pow(j), 1);
                resnets.push(EncodecResnetBlock::new(
                    channels,
                    dilations,
                    cfg,
                    layer.next(),
                )?);
            }
            layer.inc(); // ELU
            let downsample = EncodecConv1d::new(
                channels,
                channels * 2,
                ratio * 2,
                ratio,
                1,
                cfg,
                layer.next(),
            )?;
            sampling_layers.push((resnets, downsample));
            scaling *= 2;
        }

        let final_channels = cfg.num_filters * scaling;
        let final_lstm = EncodecLSTM::new(final_channels, cfg, layer.next())?;
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::new(
            final_channels,
            cfg.hidden_size,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        Ok(Self {
            init_conv,
            sampling_layers,
            final_lstm,
            final_conv,
        })
    }
}

impl Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut hidden = xs.apply(&self.init_conv)?;
        for (resnets, downsample) in &self.sampling_layers {
            hidden = resnets
                .iter()
                .try_fold(hidden, |acc, resnet| acc.apply(resnet))?;
            hidden = hidden.elu(1.0)?.apply(downsample)?;
        }
        let hidden = hidden.apply(&self.final_lstm)?.elu(1.0)?;
        hidden.apply(&self.final_conv)
    }
}
// Decoder: initial convolution and LSTM, a sequence of upsampling stages (a
// transposed convolution halving the channel count followed by resnet
// blocks) and a final convolution to `audio_channels` channels.
#[derive(Clone, Debug)]
pub struct Decoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}

impl Decoder {
    // The sub-layers are read from consecutive "layers.{i}" sub-paths, the
    // indices of the ELU activations being skipped.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
        let init_channels = cfg.num_filters * scaling;
        let init_conv = EncodecConv1d::new(
            cfg.hidden_size,
            init_channels,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        let init_lstm = EncodecLSTM::new(init_channels, cfg, layer.next())?;

        let mut sampling_layers = Vec::new();
        for &ratio in &cfg.upsampling_ratios {
            let channels = scaling * cfg.num_filters;
            let half_channels = channels / 2;
            layer.inc(); // ELU
            let upsample = EncodecConvTranspose1d::new(
                channels,
                half_channels,
                ratio * 2,
                ratio,
                cfg,
                layer.next(),
            )?;
            let mut resnets = Vec::new();
            for j in 0..(cfg.num_residual_layers as u32) {
                let dilations = (cfg.dilation_growth_rate.pow(j), 1);
                resnets.push(EncodecResnetBlock::new(
                    half_channels,
                    dilations,
                    cfg,
                    layer.next(),
                )?);
            }
            sampling_layers.push((upsample, resnets));
            scaling /= 2;
        }

        layer.inc(); // ELU
        let final_conv = EncodecConv1d::new(
            cfg.num_filters,
            cfg.audio_channels,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        Ok(Self {
            init_conv,
            init_lstm,
            sampling_layers,
            final_conv,
        })
    }
}

impl Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut hidden = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
        for (upsample, resnets) in &self.sampling_layers {
            hidden = hidden.elu(1.)?.apply(upsample)?;
            hidden = resnets
                .iter()
                .try_fold(hidden, |acc, resnet| acc.apply(resnet))?;
        }
        hidden.elu(1.)?.apply(&self.final_conv)
    }
}
/// Full codec model: an encoder, a decoder and the residual vector quantizer
/// that sits between them.
#[derive(Debug)]
pub struct Model {
/// Network run on the input in `encode`.
encoder: Encoder,
/// Network run on the de-quantized embeddings in `decode`.
decoder: Decoder,
/// Converts encoder outputs to codes and codes back to embeddings.
quantizer: ResidualVectorQuantizer,
}
impl Model {
    /// Loads the encoder, decoder and quantizer from the `"encoder"`,
    /// `"decoder"` and `"quantizer"` sub-paths of `vb`.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            encoder: Encoder::new(cfg, vb.pp("encoder"))?,
            decoder: Decoder::new(cfg, vb.pp("decoder"))?,
            quantizer: ResidualVectorQuantizer::new(cfg, vb.pp("quantizer"))?,
        })
    }

    /// Encodes `xs` and quantizes the result into codes.
    ///
    /// The quantizer output has its first two dimensions swapped before
    /// being returned.
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let latents = self.encoder.forward(xs)?;
        self.quantizer.encode(&latents)?.transpose(0, 1)
    }

    /// Reconstructs the decoder output from `codes`.
    ///
    /// `codes` must be a rank-3 tensor; its first two dimensions are swapped
    /// before being handed to the quantizer.
    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        // Only validates the rank; the individual sizes are not needed.
        codes.dims3()?;
        let swapped = codes.transpose(0, 1)?;
        let embeddings = self.quantizer.decode(&swapped)?;
        self.decoder.forward(&embeddings)
    }
}

View File

@ -1,8 +1,18 @@
use candle::{DType, Device, Result, Tensor, D}; use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000; const MAX_SEQ_LEN: usize = 5000;
/// Builds a `Linear` layer from the variables stored under `vb`.
///
/// The weight is read as `"weight"` with shape `(size2, size1)`. When `bias`
/// is true a `"bias"` variable of length `size2` is read as well; otherwise
/// the layer has no bias.
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = bias.then(|| vb.get(size2, "bias")).transpose()?;
    Ok(Linear::new(weight, bias))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> { fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) { let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
(Ok(weight), Ok(bias)) => (weight, bias), (Ok(weight), Ok(bias)) => (weight, bias),

View File

@ -1,381 +0,0 @@
use std::sync::Arc;
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};
/// Serde fallback for `Config::max_position_embeddings` when the field is
/// missing from the deserialized configuration.
fn default_max_position_embeddings() -> usize {
    const DEFAULT_MAX_POSITION_EMBEDDINGS: usize = 4096;
    DEFAULT_MAX_POSITION_EMBEDDINGS
}
/// Deserializable hyper-parameters for the model defined in this file.
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
/// Whether the q/k/v/o attention projections carry a bias term.
pub attention_bias: bool,
/// Size of each attention head; also the rotary embedding dimension.
pub head_dim: usize,
/// Activation applied to the gate projection inside the MLP.
pub hidden_act: candle_nn::Activation,
/// Width of the residual stream (embedding size).
pub hidden_size: usize,
/// Inner width of the MLP.
pub intermediate_size: usize,
/// Number of query heads.
pub num_attention_heads: usize,
/// Number of decoder layers.
pub num_hidden_layers: usize,
/// Number of key/value heads; query heads are grouped over these.
pub num_key_value_heads: usize,
/// Epsilon added inside the RMS normalization.
pub rms_norm_eps: f64,
/// Base used to compute the rotary embedding frequencies.
pub rope_theta: f64,
/// Number of rows of the token embedding table.
pub vocab_size: usize,
/// Number of positions for which rotary sin/cos tables are precomputed;
/// defaults to 4096 when absent.
#[serde(default = "default_max_position_embeddings")]
pub max_position_embeddings: usize,
}
/// RMS normalization layer whose effective scale is `weight + 1`.
#[derive(Debug, Clone)]
struct RmsNorm {
/// Learned scale, stored as an offset from one.
weight: Tensor,
/// Added to the mean square before taking the square root.
eps: f64,
}
impl RmsNorm {
    /// Loads the `"weight"` variable of length `dim` from `vb` and pairs it
    /// with the given epsilon.
    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        vb.get(dim, "weight").map(|weight| Self { weight, eps })
    }
}
impl Module for RmsNorm {
    /// Normalizes `x` over its last dimension and scales by `weight + 1`.
    ///
    /// F16/BF16 inputs are normalized in F32 and cast back to the input
    /// dtype before the scale is applied.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let input_dtype = x.dtype();
        let compute_dtype = if matches!(input_dtype, DType::F16 | DType::BF16) {
            DType::F32
        } else {
            input_dtype
        };
        let width = x.dim(D::Minus1)?;
        let x = x.to_dtype(compute_dtype)?;
        // Mean of squares over the last dimension.
        let mean_sq = (x.sqr()?.sum_keepdim(D::Minus1)? / width as f64)?;
        let rms = (mean_sq + self.eps)?.sqrt()?;
        let normed = x.broadcast_div(&rms)?.to_dtype(input_dtype)?;
        normed.broadcast_mul(&(&self.weight + 1.0)?)
    }
}
/// Precomputed rotary position tables, indexed by position along dim 0.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
/// Sine of the position/frequency products.
sin: Tensor,
/// Cosine of the position/frequency products.
cos: Tensor,
}
/// Splits the last dimension of `xs` in two halves `(a, b)` and returns
/// `(-b, a)` concatenated along that dimension.
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let width = xs.dim(D::Minus1)?;
    let split = width / 2;
    let front = xs.narrow(D::Minus1, 0, split)?;
    // The second half takes the extra element when `width` is odd.
    let back = xs.narrow(D::Minus1, split, width - split)?;
    let negated_back = back.neg()?;
    Tensor::cat(&[&negated_back, &front], D::Minus1)
}
impl RotaryEmbedding {
    /// Precomputes sin/cos tables for `cfg.max_position_embeddings`
    /// positions and `cfg.head_dim` channels, in `dtype` on `dev`.
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let head_dim = cfg.head_dim;
        let n_positions = cfg.max_position_embeddings;
        // One inverse frequency per pair of channels.
        let inv_freq = (0..head_dim)
            .step_by(2)
            .map(|i| {
                let exponent = i as f64 / head_dim as f64;
                1f32 / cfg.rope_theta.powf(exponent) as f32
            })
            .collect::<Vec<f32>>();
        let n_freqs = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, n_freqs), dev)?.to_dtype(dtype)?;
        let positions = Tensor::arange(0u32, n_positions as u32, dev)?
            .to_dtype(dtype)?
            .reshape((n_positions, 1))?;
        // Outer product position x frequency, duplicated along the last dim.
        let angles = positions.matmul(&inv_freq)?;
        let angles = Tensor::cat(&[&angles, &angles], D::Minus1)?;
        let sin = angles.sin()?;
        let cos = angles.cos()?;
        Ok(Self { sin, cos })
    }

    /// Applies the rotary embedding to `q` and `k`, using the table rows
    /// starting at `seqlen_offset`.
    ///
    /// `q` must be rank 4; its third dimension gives the number of positions
    /// taken from the tables.
    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        // Slice the tables, then add two leading axes: (1, 1, seq_len, dim).
        let cos = self
            .cos
            .narrow(0, seqlen_offset, seq_len)?
            .unsqueeze(0)?
            .unsqueeze(0)?;
        let sin = self
            .sin
            .narrow(0, seqlen_offset, seq_len)?
            .unsqueeze(0)?
            .unsqueeze(0)?;
        let rotate = |t: &Tensor| -> Result<Tensor> {
            t.broadcast_mul(&cos)? + rotate_half(t)?.broadcast_mul(&sin)?
        };
        let q_embed = rotate(q)?;
        let k_embed = rotate(k)?;
        Ok((q_embed, k_embed))
    }
}
/// Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`.
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
/// Projection whose output goes through the activation.
gate_proj: Linear,
/// Projection multiplied element-wise with the activated gate.
up_proj: Linear,
/// Projection back to the hidden size.
down_proj: Linear,
/// Activation taken from `Config::hidden_act`.
act_fn: candle_nn::Activation,
}
impl MLP {
    /// Loads the three bias-free projections from the `"gate_proj"`,
    /// `"up_proj"` and `"down_proj"` sub-paths of `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (hidden, inner) = (cfg.hidden_size, cfg.intermediate_size);
        Ok(Self {
            gate_proj: linear(hidden, inner, false, vb.pp("gate_proj"))?,
            up_proj: linear(hidden, inner, false, vb.pp("up_proj"))?,
            down_proj: linear(inner, hidden, false, vb.pp("down_proj"))?,
            act_fn: cfg.hidden_act,
        })
    }
}
impl Module for MLP {
    /// Computes `down_proj(act_fn(gate_proj(xs)) * up_proj(xs))`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate = self.gate_proj.forward(xs)?.apply(&self.act_fn)?;
        let up = self.up_proj.forward(xs)?;
        let gated = (gate * up)?;
        self.down_proj.forward(&gated)
    }
}
/// Self-attention layer with rotary embeddings, grouped key/value heads and
/// an internal key/value cache.
#[derive(Debug, Clone)]
struct Attention {
/// Query projection (`hidden_size -> num_heads * head_dim`).
q_proj: Linear,
/// Key projection (`hidden_size -> num_kv_heads * head_dim`).
k_proj: Linear,
/// Value projection (`hidden_size -> num_kv_heads * head_dim`).
v_proj: Linear,
/// Output projection (`num_heads * head_dim -> hidden_size`).
o_proj: Linear,
/// Number of query heads.
num_heads: usize,
/// Number of key/value heads.
num_kv_heads: usize,
/// `num_heads / num_kv_heads`: how many times each kv head is repeated.
num_kv_groups: usize,
/// Size of each head.
head_dim: usize,
/// Rotary tables shared between layers.
rotary_emb: Arc<RotaryEmbedding>,
/// Keys and values accumulated over previous `forward` calls, if any.
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
    /// Loads the four projections from `vb` and starts with an empty cache.
    ///
    /// Projection biases are enabled or disabled by `cfg.attention_bias`.
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let head_dim = cfg.head_dim;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let with_bias = cfg.attention_bias;
        let q_width = num_heads * head_dim;
        let kv_width = num_kv_heads * head_dim;
        Ok(Self {
            q_proj: linear(hidden_sz, q_width, with_bias, vb.pp("q_proj"))?,
            k_proj: linear(hidden_sz, kv_width, with_bias, vb.pp("k_proj"))?,
            v_proj: linear(hidden_sz, kv_width, with_bias, vb.pp("v_proj"))?,
            o_proj: linear(q_width, hidden_sz, with_bias, vb.pp("o_proj"))?,
            num_heads,
            num_kv_heads,
            num_kv_groups: num_heads / num_kv_heads,
            head_dim,
            rotary_emb,
            kv_cache: None,
        })
    }

    /// Repeats every key/value head `num_kv_groups` times along the head
    /// dimension; returns `xs` untouched when there is a single group.
    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        match self.num_kv_groups {
            1 => Ok(xs),
            n_rep => {
                let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
                let repeated = xs
                    .unsqueeze(2)?
                    .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?;
                repeated.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
            }
        }
    }

    /// Runs attention on `xs` (rank 3) and appends the new keys/values to
    /// the cache.
    ///
    /// `attention_mask`, when given, is broadcast-added to the scaled
    /// scores before the softmax. `seqlen_offset` is the position of the
    /// first token of `xs`, used to index the rotary tables.
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;
        let head_dim = self.head_dim;

        let q = self.q_proj.forward(xs)?;
        let k = self.k_proj.forward(xs)?;
        let v = self.v_proj.forward(xs)?;

        // (batch, seq, heads * dim) -> (batch, heads, seq, dim)
        let split_heads = |t: Tensor, n_heads: usize| -> Result<Tensor> {
            t.reshape((b_sz, q_len, n_heads, head_dim))?.transpose(1, 2)
        };
        let q = split_heads(q, self.num_heads)?;
        let k = split_heads(k, self.num_kv_heads)?;
        let v = split_heads(v, self.num_kv_heads)?;

        let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;

        // Prepend whatever was cached by earlier calls along the sequence axis.
        let (k, v) = if let Some((past_k, past_v)) = &self.kv_cache {
            let k = Tensor::cat(&[past_k, &k], 2)?;
            let v = Tensor::cat(&[past_v, &v], 2)?;
            (k, v)
        } else {
            (k, v)
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        let k = self.repeat_kv(k)?.contiguous()?;
        let v = self.repeat_kv(v)?.contiguous()?;

        let scale = 1f64 / f64::sqrt(head_dim as f64);
        let scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        let scores = if let Some(mask) = attention_mask {
            scores.broadcast_add(mask)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let context = probs.matmul(&v)?;

        // Back to (batch, seq, heads * dim) before the output projection.
        context
            .transpose(1, 2)?
            .reshape((b_sz, q_len, ()))?
            .apply(&self.o_proj)
    }

    /// Drops the cached keys and values.
    fn clear_kv_cache(&mut self) {
        self.kv_cache = None;
    }
}
/// One transformer layer: pre-normalized self-attention followed by a
/// pre-normalized MLP, each wrapped in a residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
/// Attention sub-layer (owns its key/value cache).
self_attn: Attention,
/// Feed-forward sub-layer.
mlp: MLP,
/// Normalization applied before the attention.
input_layernorm: RmsNorm,
/// Normalization applied before the MLP.
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
    /// Loads the attention, MLP and both normalization layers from the
    /// `"self_attn"`, `"mlp"`, `"input_layernorm"` and
    /// `"post_attention_layernorm"` sub-paths of `vb`.
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (width, eps) = (cfg.hidden_size, cfg.rms_norm_eps);
        Ok(Self {
            self_attn: Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?,
            mlp: MLP::new(cfg, vb.pp("mlp"))?,
            input_layernorm: RmsNorm::new(width, eps, vb.pp("input_layernorm"))?,
            post_attention_layernorm: RmsNorm::new(
                width,
                eps,
                vb.pp("post_attention_layernorm"),
            )?,
        })
    }

    /// Applies `x + attn(norm(x))` followed by `y + mlp(norm(y))`.
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let normed = self.input_layernorm.forward(xs)?;
        let attn_out = self
            .self_attn
            .forward(&normed, attention_mask, seqlen_offset)?;
        let after_attn = (attn_out + xs)?;
        let mlp_out = after_attn
            .apply(&self.post_attention_layernorm)?
            .apply(&self.mlp)?;
        &after_attn + mlp_out
    }

    /// Drops the attention sub-layer's cached keys and values.
    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}
/// Decoder-only language model: token embedding, a stack of decoder layers,
/// a final normalization and an output head tied to the embedding table.
#[derive(Debug, Clone)]
pub struct Model {
/// Token embedding table.
embed_tokens: candle_nn::Embedding,
/// Decoder layers, applied in order.
layers: Vec<DecoderLayer>,
/// Normalization applied to the last position before the head.
norm: RmsNorm,
/// Bias-free projection to the vocabulary, built from the embedding weights.
lm_head: Linear,
/// Device on which attention masks are created.
device: Device,
/// Dtype to which attention masks are converted.
dtype: DType,
/// Hidden size; its square root scales the embeddings in `forward`.
hidden_size: usize,
}
impl Model {
    /// Loads the model weights from the `"model"` sub-path of `vb`.
    ///
    /// The output head reuses the token embedding matrix and has no bias.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        // A single set of rotary tables is shared by every layer.
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let vb_l = vb_m.pp("layers");
        let layers = (0..cfg.num_hidden_layers)
            .map(|layer_idx| DecoderLayer::new(Arc::clone(&rotary_emb), cfg, vb_l.pp(layer_idx)))
            .collect::<Result<Vec<_>>>()?;
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
            hidden_size: cfg.hidden_size,
        })
    }

    /// Builds an additive causal mask of shape
    /// `(b_size, 1, tgt_len, tgt_len + seqlen_offset)`.
    ///
    /// Entries are `-inf` where a query would look at a later position and
    /// zero elsewhere; the `seqlen_offset` leading columns are all zero.
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mut causal = Vec::with_capacity(tgt_len * tgt_len);
        for row in 0..tgt_len {
            causal.extend((0..tgt_len).map(|col| {
                if col > row {
                    f32::NEG_INFINITY
                } else {
                    0f32
                }
            }));
        }
        let causal = Tensor::from_slice(&causal, (tgt_len, tgt_len), &self.device)?;
        let mask = match seqlen_offset {
            0 => causal,
            offset => {
                let visible = Tensor::zeros((tgt_len, offset), DType::F32, &self.device)?;
                Tensor::cat(&[&visible, &causal], D::Minus1)?
            }
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    /// Runs the model on `input_ids` (rank 2) and returns the output of the
    /// head for the last position only.
    ///
    /// `seqlen_offset` is the number of positions already processed and held
    /// in the layers' key/value caches.
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        // No mask is built when at most one position is processed.
        let attention_mask = if seq_len > 1 {
            Some(self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?)
        } else {
            None
        };
        let embedded = self.embed_tokens.forward(input_ids)?;
        let mut hidden = (embedded * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter_mut() {
            hidden = layer.forward(&hidden, attention_mask.as_ref(), seqlen_offset)?;
        }
        let last = hidden.narrow(1, seq_len - 1, 1)?;
        last.apply(&self.norm)?.apply(&self.lm_head)
    }

    /// Drops the cached keys and values of every layer.
    pub fn clear_kv_cache(&mut self) {
        self.layers
            .iter_mut()
            .for_each(|layer| layer.clear_kv_cache());
    }
}

View File

@ -2,6 +2,7 @@ use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder}; use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096; pub const MAX_SEQ_LEN: usize = 4096;
@ -83,9 +84,10 @@ impl Config {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Cache { pub struct Cache {
masks: HashMap<usize, Tensor>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool, pub use_kv_cache: bool,
kvs: Vec<Option<(Tensor, Tensor)>>, #[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor, cos: Tensor,
sin: Tensor, sin: Tensor,
device: Device, device: Device,
@ -110,24 +112,25 @@ impl Cache {
let cos = idx_theta.cos()?.to_dtype(dtype)?; let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?; let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self { Ok(Self {
masks: HashMap::new(), masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache, use_kv_cache,
kvs: vec![None; config.num_hidden_layers], kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
device: device.clone(), device: device.clone(),
cos, cos,
sin, sin,
}) })
} }
fn mask(&mut self, t: usize) -> Result<Tensor> { fn mask(&self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) { let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())
} else { } else {
let mask: Vec<_> = (0..t) let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i))) .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect(); .collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
self.masks.insert(t, mask.clone()); masks.insert(t, mask.clone());
Ok(mask) Ok(mask)
} }
} }
@ -161,6 +164,7 @@ struct CausalSelfAttention {
num_attention_heads: usize, num_attention_heads: usize,
num_key_value_heads: usize, num_key_value_heads: usize,
head_dim: usize, head_dim: usize,
cache: Cache,
use_flash_attn: bool, use_flash_attn: bool,
span: tracing::Span, span: tracing::Span,
span_rot: tracing::Span, span_rot: tracing::Span,
@ -183,11 +187,11 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
} }
impl CausalSelfAttention { impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter(); let _enter = self.span_rot.enter();
let (b_sz, _, seq_len, hidden_size) = x.dims4()?; let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
let cos = cache.cos.narrow(0, index_pos, seq_len)?; let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
let sin = cache.sin.narrow(0, index_pos, seq_len)?; let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?; let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?; let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?; let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
@ -197,13 +201,7 @@ impl CausalSelfAttention {
Ok(rope) Ok(rope)
} }
fn forward( fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let (b_sz, seq_len, hidden_size) = x.dims3()?; let (b_sz, seq_len, hidden_size) = x.dims3()?;
let q = self.q_proj.forward(x)?; let q = self.q_proj.forward(x)?;
@ -220,11 +218,12 @@ impl CausalSelfAttention {
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
.transpose(1, 2)?; .transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?; let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?; let mut k = self.apply_rotary_emb(&k, index_pos)?;
if cache.use_kv_cache { if self.cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] { let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1]; let k_seq_len = k.dims()[1];
@ -240,7 +239,7 @@ impl CausalSelfAttention {
.contiguous()? .contiguous()?
} }
} }
cache.kvs[block_idx] = Some((k.clone(), v.clone())) cache[block_idx] = Some((k.clone(), v.clone()))
} }
let k = self.repeat_kv(k)?; let k = self.repeat_kv(k)?;
@ -259,7 +258,7 @@ impl CausalSelfAttention {
let k = k.to_dtype(DType::F32)?; let k = k.to_dtype(DType::F32)?;
let v = v.to_dtype(DType::F32)?; let v = v.to_dtype(DType::F32)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?; let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
@ -284,7 +283,7 @@ impl CausalSelfAttention {
} }
} }
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "attn"); let span = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let size_in = cfg.hidden_size; let size_in = cfg.hidden_size;
@ -302,6 +301,7 @@ impl CausalSelfAttention {
num_attention_heads: cfg.num_attention_heads, num_attention_heads: cfg.num_attention_heads,
num_key_value_heads: cfg.num_key_value_heads, num_key_value_heads: cfg.num_key_value_heads,
head_dim: cfg.hidden_size / cfg.num_attention_heads, head_dim: cfg.hidden_size / cfg.num_attention_heads,
cache: cache.clone(),
use_flash_attn: cfg.use_flash_attn, use_flash_attn: cfg.use_flash_attn,
span, span,
span_rot, span_rot,
@ -357,25 +357,19 @@ struct Block {
} }
impl Block { impl Block {
fn forward( fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let _enter = self.span.enter(); let _enter = self.span.enter();
let residual = x; let residual = x;
let x = self.rms_1.forward(x)?; let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?; let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x; let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x) Ok(x)
} }
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "block"); let span = tracing::span!(tracing::Level::TRACE, "block");
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let rms_2 = RmsNorm::load( let rms_2 = RmsNorm::load(
@ -402,11 +396,11 @@ pub struct Llama {
} }
impl Llama { impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> { pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?; let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?; let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() { for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?; x = block.forward(&x, index_pos, block_idx)?;
} }
let x = self.ln_f.forward(&x)?; let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?; let x = x.i((.., seq_len - 1, ..))?;
@ -414,12 +408,12 @@ impl Llama {
logits.to_dtype(DType::F32) logits.to_dtype(DType::F32)
} }
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.num_hidden_layers) let blocks: Vec<_> = (0..cfg.num_hidden_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap()) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
.collect(); .collect();
Ok(Self { Ok(Self {

View File

@ -2,6 +2,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear; use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Config { pub struct Config {
@ -69,11 +70,12 @@ impl Config {
} }
} }
#[derive(Debug, Clone)] #[derive(Clone)]
pub struct Cache { pub struct Cache {
masks: HashMap<usize, Tensor>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool, pub use_kv_cache: bool,
pub kvs: Vec<Option<(Tensor, Tensor)>>, #[allow(clippy::type_complexity)]
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub cos: Tensor, pub cos: Tensor,
pub sin: Tensor, pub sin: Tensor,
device: Device, device: Device,
@ -103,24 +105,25 @@ impl Cache {
let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?; let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
Ok(Self { Ok(Self {
masks: HashMap::new(), masks: Arc::new(Mutex::new(HashMap::new())),
use_kv_cache, use_kv_cache,
kvs: vec![None; cfg.n_layers], kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
cos, cos,
sin, sin,
device: vb.device().clone(), device: vb.device().clone(),
}) })
} }
pub fn mask(&mut self, t: usize) -> Result<Tensor> { pub fn mask(&self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) { let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())
} else { } else {
let mask: Vec<_> = (0..t) let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i))) .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect(); .collect();
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?; let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
self.masks.insert(t, mask.clone()); masks.insert(t, mask.clone());
Ok(mask) Ok(mask)
} }
} }
@ -130,7 +133,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)? xs / (xs.neg()?.exp()? + 1.0)?
} }
#[derive(Debug, Clone)]
struct CausalSelfAttention { struct CausalSelfAttention {
q_proj: Linear, q_proj: Linear,
k_proj: Linear, k_proj: Linear,
@ -139,13 +141,14 @@ struct CausalSelfAttention {
n_head: usize, n_head: usize,
n_key_value_head: usize, n_key_value_head: usize,
head_dim: usize, head_dim: usize,
cache: Cache,
} }
impl CausalSelfAttention { impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?; let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(index_pos..index_pos + seq_len)?; let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = cache.sin.i(index_pos..index_pos + seq_len)?; let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?; let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?; let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -159,13 +162,7 @@ impl CausalSelfAttention {
Ok(rope) Ok(rope)
} }
fn forward( fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?; let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?; let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?; let k = self.k_proj.forward(x)?;
@ -175,15 +172,16 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?; let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?; let mut k = self.apply_rotary_emb(&k, index_pos)?;
if cache.use_kv_cache { if self.cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] { let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
} }
cache.kvs[block_idx] = Some((k.clone(), v.clone())) cache[block_idx] = Some((k.clone(), v.clone()))
} }
let k = self.repeat_kv(k)?; let k = self.repeat_kv(k)?;
@ -194,7 +192,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?; let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?; let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
@ -218,7 +216,7 @@ impl CausalSelfAttention {
} }
} }
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim; let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads; let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads; let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -234,6 +232,7 @@ impl CausalSelfAttention {
n_head: cfg.n_heads, n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads, n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads, head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
}) })
} }
} }
@ -245,7 +244,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug, Clone)]
struct Mlp { struct Mlp {
c_fc1: Linear, c_fc1: Linear,
c_fc2: Linear, c_fc2: Linear,
@ -276,7 +274,6 @@ impl Mlp {
} }
} }
#[derive(Debug, Clone)]
struct Block { struct Block {
rms_1: RmsNorm, rms_1: RmsNorm,
attn: CausalSelfAttention, attn: CausalSelfAttention,
@ -294,23 +291,17 @@ impl Block {
} }
} }
fn forward( fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let residual = x; let residual = x;
let x = self.rms_1.forward(x)?; let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?; let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x; let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x) Ok(x)
} }
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = let post_attention_layernorm =
@ -324,7 +315,6 @@ impl Block {
} }
} }
#[derive(Debug, Clone)]
pub struct Llama { pub struct Llama {
wte: Embedding, wte: Embedding,
blocks: Vec<Block>, blocks: Vec<Block>,
@ -334,23 +324,23 @@ pub struct Llama {
} }
impl Llama { impl Llama {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> { pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?; let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?; let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() { for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?; x = block.forward(&x, index_pos, block_idx)?;
} }
let x = self.ln_f.forward(&x)?; let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?; let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32) logits.to_dtype(DType::F32)
} }
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers) let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap()) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect(); .collect();
Ok(Self { Ok(Self {
wte, wte,

View File

@ -32,9 +32,9 @@ impl Config {
} }
pub struct State { pub struct State {
pub hs: Vec<Tensor>, hs: Vec<Tensor>,
pub prev_xs: Vec<[Tensor; D_CONV]>, prev_xs: Vec<[Tensor; D_CONV]>,
pub pos: usize, pos: usize,
} }
impl State { impl State {

View File

@ -1,968 +0,0 @@
use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
/// Repeats every entry of `img` along `dim` `repeats` times in a row, in the spirit of
/// `torch.repeat_interleave`.
///
/// A new axis is inserted right after `dim`, broadcast to size `repeats`, and then merged
/// back into `dim`.
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let expanded = img.unsqueeze(dim + 1)?;
    let target_shape: Vec<usize> = expanded
        .dims()
        .iter()
        .enumerate()
        .map(|(axis, &size)| if axis == dim + 1 { repeats } else { size })
        .collect();
    expanded.broadcast_as(target_shape)?.flatten(dim, dim + 1)
}
pub mod speaker_encoder {
use super::*;
/// Hyper-parameters of the speaker encoder.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
/// Audio sampling rate; combined with `mel_window_step / 1000` to get samples per frame.
pub sampling_rate: usize,
/// Number of mel frames covered by one partial slice.
pub partial_n_frames: usize,
/// Hidden size of every LSTM layer.
pub model_hidden_size: usize,
/// Output size of the final linear projection.
pub model_embedding_size: usize,
/// Number of stacked LSTM layers.
pub model_num_layers: usize,
/// Passed as the fft size to the mel spectrogram helper.
pub mel_window_length: usize,
/// Passed as the fft step to the mel spectrogram helper (presumably milliseconds, given the
/// `/ 1000` in the frame computation; confirm).
pub mel_window_step: usize,
/// Number of mel channels; also the input size of the first LSTM layer.
pub mel_n_channels: usize,
}
impl Config {
/// Returns the hard-coded set of values defined below.
pub fn cfg() -> Self {
Self {
sampling_rate: 16_000,
partial_n_frames: 160,
model_hidden_size: 256,
model_embedding_size: 256,
model_num_layers: 3,
mel_window_length: 25,
mel_window_step: 10,
mel_n_channels: 40,
}
}
}
/// Speaker encoder made of a stack of LSTM layers followed by a linear projection.
pub struct Model {
// LSTM layers, applied one after the other in `forward`.
lstms: Vec<candle_nn::LSTM>,
// Projection from `model_hidden_size` to `model_embedding_size`.
linear: Linear,
cfg: Config,
}
/// `(start, end)` bounds of a slice, in samples for wav slices and in frames for mel slices.
type Slice = (usize, usize);
impl Model {
/// Builds the encoder from `cfg`, reading the LSTM weights under the `lstm` prefix and the
/// projection weights under the `linear` prefix of `vb`.
///
/// Returns an error if any layer fails to load.
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
    let lstm_vb = vb.pp("lstm");
    let lstms = (0..cfg.model_num_layers)
        .map(|layer_idx| {
            // Only the first layer consumes mel channels, the others consume hidden states.
            let input_size = match layer_idx {
                0 => cfg.mel_n_channels,
                _ => cfg.model_hidden_size,
            };
            let lstm_cfg = candle_nn::LSTMConfig {
                layer_idx,
                ..Default::default()
            };
            candle_nn::lstm(input_size, cfg.model_hidden_size, lstm_cfg, lstm_vb.clone())
        })
        .collect::<Result<Vec<_>>>()?;
    let projection_vb = vb.pp("linear");
    let linear = linear_b(
        cfg.model_hidden_size,
        cfg.model_embedding_size,
        true,
        projection_vb,
    )?;
    Ok(Self { lstms, linear, cfg })
}
/// Splits an utterance of `n_samples` samples into overlapping partial slices.
///
/// `rate` is the number of partial slices per second and `min_coverage` the minimal fraction
/// of the last slice that has to be backed by actual samples for that slice to be kept.
///
/// Returns `(wav_slices, mel_slices)`: the bounds of each partial slice expressed in samples
/// and in mel frames respectively. Both vectors always have the same length.
fn compute_partial_slices(
    &self,
    n_samples: usize,
    rate: f64,
    min_coverage: f64,
) -> (Vec<Slice>, Vec<Slice>) {
    let c = &self.cfg;
    // Compute how many frames separate two partial utterances
    let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
    let n_frames = n_samples / samples_per_frame + 1;
    let frame_step =
        (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
    let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
    // Compute the slices.
    let mut wav_slices = vec![];
    let mut mel_slices = vec![];
    for i in (0..steps).step_by(frame_step) {
        let mel_range = (i, i + c.partial_n_frames);
        let wav_range = (
            i * samples_per_frame,
            (i + c.partial_n_frames) * samples_per_frame,
        );
        mel_slices.push(mel_range);
        wav_slices.push(wav_range);
    }
    // Evaluate whether extra padding is warranted or not.
    let last_wav_range = match wav_slices.last() {
        None => return (wav_slices, mel_slices),
        Some(l) => *l,
    };
    // A slice starting past the end of the audio has no coverage at all; saturate rather
    // than underflow.
    let covered = n_samples.saturating_sub(last_wav_range.0) as f64;
    let coverage = covered / (last_wav_range.1 - last_wav_range.0) as f64;
    // Drop the last slice when too little of it is backed by real samples, i.e. when it
    // would mostly consist of padding. At least one slice is always kept.
    if coverage < min_coverage && mel_slices.len() > 1 {
        mel_slices.pop();
        wav_slices.pop();
    }
    (wav_slices, mel_slices)
}
/// Computes a unit-norm embedding for the utterance `wav`.
///
/// The audio is split into partial slices (see `compute_partial_slices`, driven by `rate`
/// and `min_c`), zero-padded so that the last slice is complete, turned into a mel
/// spectrogram using `mel_filters`, and run through the model. The partial embeddings are
/// averaged and the result is divided by its L2 norm.
///
/// Returns an error if no slice could be computed or if a tensor operation fails.
pub fn embed_utterance(
    &self,
    wav: &[f32],
    mel_filters: &[f32],
    rate: f64,
    min_c: f64,
    device: &Device,
) -> Result<Tensor> {
    let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
    let max_wave_length = match wav_slices.last() {
        Some(v) => v.1,
        None => candle::bail!("empty wav slices"),
    };
    let wav = if max_wave_length > wav.len() {
        let mut wav = wav.to_vec();
        // `Vec::resize` takes the new total length, not the number of elements to append:
        // zero-pad up to the end of the last slice.
        wav.resize(max_wave_length, 0.0);
        std::borrow::Cow::Owned(wav)
    } else {
        std::borrow::Cow::Borrowed(wav)
    };
    let mel = crate::models::whisper::audio::log_mel_spectrogram_(
        wav.as_ref(),
        mel_filters,
        /* fft_size */ self.cfg.mel_window_length,
        /* fft_step */ self.cfg.mel_window_step,
        self.cfg.mel_n_channels,
        false,
    );
    // NOTE(review): only the two values at the slice bounds are read here, and `mel[s.1]`
    // panics if the bound is past the end of `mel`. This looks like a placeholder for
    // extracting the full frame range of each slice; confirm against the mel layout.
    let mels = mel_slices
        .iter()
        .flat_map(|s| [mel[s.0], mel[s.1]])
        .collect::<Vec<_>>();
    let mels = Tensor::from_vec(mels, (1, mel_slices.len(), 2), device)?;
    let partial_embeds = self.forward(&mels)?;
    let raw_embed = partial_embeds.mean(0)?;
    let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
    raw_embed.broadcast_div(&norm)
}
}
impl Module for Model {
    /// Runs `xs` through the LSTM stack and the projection, applies a ReLU, and divides the
    /// result by its L2 norm computed along dimension 1.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;

        // Unlike the Python transformers version, the candle LSTM is batch first, hence the
        // transpositions around the recurrent stack.
        let mut hidden = xs.t()?;
        for lstm in &self.lstms {
            let states = lstm.seq(&hidden)?;
            hidden = lstm.states_to_tensor(&states)?;
        }
        let embeds_raw = hidden.t()?.apply(&self.linear)?.relu()?;
        let l2_norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
        embeds_raw.broadcast_div(&l2_norm)
    }
}
}
/// Rank of a byte sequence in the BPE vocabulary; also used as the token id before the offset
/// is applied.
type Rank = u32;
pub mod tokenizers {
use super::*;
use std::collections::HashMap;
/// Byte-pair encoder built from a json description (see `from_json`).
pub struct BPE {
/// Regex used by `encode` to split the input text into pieces.
pub re: fancy_regex::Regex,
/// Id of the end-of-text token (before `offset` is added), appended by `encode`.
pub end_of_text: usize,
/// Value added to every token id returned by `encode`.
pub offset: usize,
/// Rank of each known byte sequence.
pub ranks: HashMap<Vec<u8>, Rank>,
}
impl BPE {
pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
let json = match json.as_object() {
None => candle::bail!("json value is not an object"),
Some(json) => json,
};
let re = match json.get("pat_str") {
None => candle::bail!("json object has no pat_str field"),
Some(pat_str) => match pat_str.as_str() {
None => candle::bail!("pat_str field is not a string"),
Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
},
};
let offset = match json.get("offset") {
None => candle::bail!("json object has no offset field"),
Some(offset) => match offset.as_u64() {
None => candle::bail!("offset field is not a positive int"),
Some(offset) => offset as usize,
},
};
let mut ranks = HashMap::new();
for id in 0u8..=255 {
ranks.insert(vec![id], id as u32);
}
let mergeable_ranks = match json.get("mergeable_ranks") {
None => candle::bail!("json object has no mergeable_ranks field"),
Some(mr) => match mr.as_object() {
None => candle::bail!("mergeable_ranks is not an object"),
Some(mr) => mr,
},
};
for (key, value) in mergeable_ranks.iter() {
let value = match value.as_u64() {
None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
Some(value) => value as u32,
};
if value < 256 {
continue;
}
// No escaping for other keys.
let key = key.as_bytes().to_vec();
ranks.insert(key, value);
}
Ok(Self {
re,
end_of_text,
offset,
ranks,
})
}
// Taken from:
// https://github.com/openai/tiktoken/blob/1b9faf2779855124f05174adf1383e53689ed94b/src/lib.rs#L16C1-L82C2
//
// Greedily merges adjacent parts of `piece`, always picking the pair with the lowest rank,
// until no remaining adjacent pair is found in `ranks`.
//
// Returns a list of `(start, rank)` entries; the `start` values of consecutive entries delimit
// the merged parts, with a final entry whose start is `piece.len()`.
//
// `piece` must not be empty: `piece.len() - 1` below would underflow.
fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
// This is a vector of (start, rank).
// The rank is of the pair starting at position start.
let mut parts = Vec::with_capacity(piece.len() + 1);
// Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE
// the way we currently do, this is equivalent. An easy way to break this would be to decouple
// merge priority from token index or to prevent specific token merges.
let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
for i in 0..piece.len() - 1 {
let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
if rank < min_rank.0 {
min_rank = (rank, i);
}
parts.push((i, rank));
}
// Two trailing sentinels: the start of the last byte and the end of the piece.
parts.push((piece.len() - 1, Rank::MAX));
parts.push((piece.len(), Rank::MAX));
// Rank of the pair that would start at `parts[i]` once `parts[i + 1]` has been removed,
// or `Rank::MAX` when there is no such pair or it is unknown.
let get_rank = {
#[inline(always)]
|parts: &Vec<(usize, Rank)>, i: usize| {
if (i + 3) < parts.len() {
// Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted
// parts[i + 1], see comment in the main loop.
*self
.ranks
.get(&piece[parts[i].0..parts[i + 3].0])
.unwrap_or(&Rank::MAX)
} else {
Rank::MAX
}
}
};
// If you have n parts and m merges, this does O(mn) work.
// We could do something with a heap and do O(m log n) work.
// n is often very small so considerations like cache-locality outweigh the algorithmic
// complexity downsides of the `parts` vector.
while min_rank.0 != Rank::MAX {
let i = min_rank.1;
// Update parts[i] and parts[i - 1] before removing parts[i + 1], since
// `parts.remove(i + 1)` will thrash the cache.
if i > 0 {
parts[i - 1].1 = get_rank(&parts, i - 1);
}
parts[i].1 = get_rank(&parts, i);
parts.remove(i + 1);
// Rescan for the next lowest-ranked pair; the last entry is a sentinel and is skipped.
min_rank = (Rank::MAX, usize::MAX);
for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
if rank < min_rank.0 {
min_rank = (rank, i);
}
}
}
parts
}
/// Encodes a single piece into the ranks of its merged parts.
///
/// An empty piece yields no token. Panics if a part is missing from `ranks`.
pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
    match piece.len() {
        0 => Vec::new(),
        1 => vec![self.ranks[piece]],
        _ => {
            let boundaries = self._byte_pair_merge(piece);
            boundaries
                .windows(2)
                .map(|pair| {
                    // Consecutive starts delimit one merged part.
                    let (start, end) = (pair[0].0, pair[1].0);
                    self.ranks[&piece[start..end]]
                })
                .collect()
        }
    }
}
/// Tokenizes `text`: the text is split with the regex, each piece is byte-pair encoded, and
/// `offset` is added to every id. The shifted end-of-text id is appended at the end.
///
/// Returns an error if the regex matching fails.
pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
    let mut bpe_tokens: Vec<u32> = Vec::new();
    for piece in self.re.find_iter(text) {
        let piece = piece.map_err(E::wrap)?;
        let ids = self.byte_pair_encode(piece.as_str().as_bytes());
        bpe_tokens.extend(ids.into_iter().map(|token| token + self.offset as u32));
    }
    let end_of_text = (self.end_of_text + self.offset) as u32;
    bpe_tokens.push(end_of_text);
    Ok(bpe_tokens)
}
}
}
pub mod gpt {
use super::*;
/// Kind of normalization layer selected by the config.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum NormType {
LayerNorm,
RMSNorm,
}
/// Attention implementation named by the config. `SelfAttention::new` only accepts
/// `TorchAttn`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum AttnKernelType {
Fa2,
TorchAttn,
Hand,
}
/// Non-linearity used by the MLP.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum NonLinearityType {
Gelu,
Swiglu,
}
// Normalization layer, dispatching to the implementation matching `NormType`.
enum Norm {
RMSNorm(candle_nn::RmsNorm),
LayerNorm(candle_nn::LayerNorm),
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L27
/// Configuration of the gpt model.
#[derive(Debug, Clone)]
pub struct Config {
/// Number of entries of the positional embedding table.
pub block_size: usize,
/// Vocabulary size of each input embedding table, one table per entry.
pub vocab_sizes: Vec<usize>,
/// Output size of each head, one head per entry.
pub target_vocab_sizes: Vec<usize>,
/// Number of transformer blocks.
pub n_layer: usize,
/// Number of attention heads.
pub n_head: usize,
/// Embedding / hidden size.
pub n_embd: usize,
/// Whether the attention and MLP linear layers have a bias; also used as the `affine` flag
/// of the layer norm.
pub bias: bool,
/// Not read by the code of this module.
pub causal: bool,
/// Not read by the code of this module.
pub spk_emb_on_text: bool,
pub norm_type: NormType,
/// Epsilon used when `norm_type` is `RMSNorm`.
pub rmsnorm_eps: f64,
pub nonlinearity_type: NonLinearityType,
/// Has to be set when `nonlinearity_type` is `Swiglu`.
pub swiglu_multiple_of: Option<usize>,
/// Only `TorchAttn` is accepted.
pub attn_kernel_type: AttnKernelType,
/// Has to be false, the kv cache is not supported.
pub kv_cache_enabled: bool,
}
impl Config {
/// Returns the hard-coded set of values defined below.
pub fn cfg1b_v0_1() -> Self {
Self {
n_layer: 6,
n_head: 6,
n_embd: 384,
block_size: 1024,
bias: false,
vocab_sizes: vec![1538, 1025],
causal: false,
target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
swiglu_multiple_of: Some(256),
norm_type: NormType::LayerNorm,
kv_cache_enabled: false,
attn_kernel_type: AttnKernelType::TorchAttn,
spk_emb_on_text: true,
nonlinearity_type: NonLinearityType::Gelu,
rmsnorm_eps: 1e-5,
}
}
}
impl Norm {
    /// Loads the normalization layer selected by `cfg.norm_type` from `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let norm = match cfg.norm_type {
            NormType::LayerNorm => {
                let ln_cfg = candle_nn::LayerNormConfig {
                    affine: cfg.bias,
                    ..Default::default()
                };
                Self::LayerNorm(candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?)
            }
            NormType::RMSNorm => {
                Self::RMSNorm(candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?)
            }
        };
        Ok(norm)
    }
}
impl Module for Norm {
    /// Applies whichever normalization layer this value wraps.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::LayerNorm(inner) => xs.apply(inner),
            Self::RMSNorm(inner) => xs.apply(inner),
        }
    }
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/attn.py#L18
/// Multi-head self-attention with a fused query/key/value projection.
struct SelfAttention {
    c_attn: Linear,
    c_proj: Linear,
    n_head: usize,
}

impl SelfAttention {
    /// Loads the attention weights from `vb`.
    ///
    /// Returns an error for any attention kernel other than `TorchAttn` or when the kv cache
    /// is enabled.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // The different attention variants are likely to be identical but still we only accept
        // TorchAttn for now.
        match (cfg.attn_kernel_type, cfg.kv_cache_enabled) {
            (AttnKernelType::TorchAttn, false) => {}
            (AttnKernelType::TorchAttn, true) => {
                candle::bail!("kv_cache_enabled=true is not supported")
            }
            (AttnKernelType::Fa2 | AttnKernelType::Hand, _) => {
                candle::bail!("only TorchAttn is supported")
            }
        }
        let qkv_size = cfg.n_embd * 3;
        let c_attn = linear_b(cfg.n_embd, qkv_size, cfg.bias, vb.pp("c_attn"))?;
        let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
        let n_head = cfg.n_head;
        Ok(Self {
            c_attn,
            c_proj,
            n_head,
        })
    }
}
impl Module for SelfAttention {
    /// Applies scaled dot-product attention to a rank-3 input; no mask is applied.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, t, c) = xs.dims3()?;
        let head_dim = c / self.n_head;
        let qkv = xs
            .apply(&self.c_attn)?
            .reshape((b, t, 3, self.n_head, head_dim))?;
        let q = qkv.i((.., .., 0))?;
        let k = qkv.i((.., .., 1))?;
        let v = qkv.i((.., .., 2))?;
        // Move the head dimension in front of the sequence dimension.
        let to_heads = |ys: Tensor| -> Result<Tensor> { ys.transpose(1, 2)?.contiguous() };
        let q = to_heads(q)?;
        let k = to_heads(k)?;
        let v = to_heads(v)?;
        let scores = q.matmul(&k.t()?)?;
        let scale = (k.dim(D::Minus1)? as f64).sqrt();
        let scores = (scores / scale)?;
        // TODO: causal mask
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let merged = probs.matmul(&v)?.transpose(1, 2)?.reshape((b, t, c))?;
        merged.apply(&self.c_proj)
    }
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/layers.py#L43
// Feed-forward part of a block, in one of two flavours selected by `NonLinearityType`.
#[allow(clippy::upper_case_acronyms)]
enum MLP {
// `c_proj(gelu(c_fc(x)))`
Gelu {
c_fc: Linear,
c_proj: Linear,
},
// `c_proj(silu(w1(x)) * w3(x))`
Swiglu {
w1: Linear,
w3: Linear,
c_proj: Linear,
},
}
impl MLP {
    /// Loads the MLP weights from `vb` for the non-linearity selected by `cfg`.
    ///
    /// The gelu variant uses a hidden size of `4 * n_embd`. The swiglu variant uses two
    /// thirds of that, rounded up to the next multiple of `cfg.swiglu_multiple_of`.
    ///
    /// Returns an error if the swiglu variant is selected and `swiglu_multiple_of` is unset
    /// or zero, or if a weight fails to load.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_dim = 4 * cfg.n_embd;
        let slf = match cfg.nonlinearity_type {
            NonLinearityType::Gelu => {
                let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
                let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
                Self::Gelu { c_fc, c_proj }
            }
            NonLinearityType::Swiglu => {
                let hidden_dim = (2 * hidden_dim) / 3;
                let swiglu_multiple_of = match cfg.swiglu_multiple_of {
                    None => candle::bail!("swiglu-multiple-of has to be set"),
                    Some(0) => candle::bail!("swiglu-multiple-of has to be non-zero"),
                    Some(smo) => smo,
                };
                // Round up to a multiple: the division has to happen before the
                // multiplication, otherwise `m * x / m` just evaluates to `x`.
                let hidden_dim = swiglu_multiple_of
                    * ((hidden_dim + swiglu_multiple_of - 1) / swiglu_multiple_of);
                let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
                let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
                let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
                Self::Swiglu { w1, w3, c_proj }
            }
        };
        Ok(slf)
    }
}
impl Module for MLP {
    /// Applies the feed-forward network matching the variant.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Swiglu { w1, w3, c_proj } => {
                let gate = xs.apply(w1)?;
                let up = xs.apply(w3)?;
                let gated = (gate.silu()? * up)?;
                gated.apply(c_proj)
            }
            Self::Gelu { c_fc, c_proj } => {
                let hidden = xs.apply(c_fc)?.gelu()?;
                hidden.apply(c_proj)
            }
        }
    }
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/combined.py#L7
/// Transformer block: normalization + attention, then normalization + MLP, each with a
/// residual connection.
struct Block {
    ln_1: Norm,
    ln_2: Norm,
    attn: SelfAttention,
    mlp: MLP,
}

impl Block {
    /// Loads the sub-layers from the `ln_1`, `ln_2`, `attn` and `mlp` prefixes of `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // Field initializers run in the order written, so layers load in this order.
        Ok(Self {
            ln_1: Norm::new(cfg, vb.pp("ln_1"))?,
            ln_2: Norm::new(cfg, vb.pp("ln_2"))?,
            attn: SelfAttention::new(cfg, vb.pp("attn"))?,
            mlp: MLP::new(cfg, vb.pp("mlp"))?,
        })
    }
}
impl Module for Block {
    /// Adds the attention output and then the MLP output to the residual stream, each
    /// computed on a normalized copy of the stream.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let attn_out = xs.apply(&self.ln_1)?.apply(&self.attn)?;
        let residual = (xs + &attn_out)?;
        let mlp_out = residual.apply(&self.ln_2)?.apply(&self.mlp)?;
        &residual + &mlp_out
    }
}
// https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L79
/// Transformer with one token embedding table per input hierarchy and one output head per
/// target vocabulary.
#[allow(clippy::upper_case_acronyms)]
pub struct Model {
// Token embedding tables, one per entry of `cfg.vocab_sizes`.
wtes: Vec<candle_nn::Embedding>,
// Positional embedding table.
wpe: candle_nn::Embedding,
// Transformer blocks.
h: Vec<Block>,
// Final normalization.
ln_f: Norm,
// Output heads, one per entry of `cfg.target_vocab_sizes`.
lm_heads: Vec<Linear>,
cfg: Config,
// Dtype of the var builder the model was loaded from.
dtype: DType,
}
impl Model {
/// Loads the model described by `cfg` from `vb`.
///
/// The embeddings, blocks and final norm are read under the `transformer` prefix, the output
/// heads under `lm_heads`. Returns an error if any weight fails to load.
pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
    let vb_t = vb.pp("transformer");
    let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;

    let vb_w = vb_t.pp("wtes");
    let wtes = cfg
        .vocab_sizes
        .iter()
        .enumerate()
        .map(|(idx, &vocab_size)| candle_nn::embedding(vocab_size, cfg.n_embd, vb_w.pp(idx)))
        .collect::<Result<Vec<_>>>()?;

    let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;

    let vb_h = vb_t.pp("h");
    let h = (0..cfg.n_layer)
        .map(|idx| Block::new(&cfg, vb_h.pp(idx)))
        .collect::<Result<Vec<_>>>()?;

    let vb_l = vb.pp("lm_heads");
    let lm_heads = cfg
        .target_vocab_sizes
        .iter()
        .enumerate()
        .map(|(idx, &vocab_size)| linear_b(cfg.n_embd, vocab_size, false, vb_l.pp(idx)))
        .collect::<Result<Vec<_>>>()?;

    let dtype = vb.dtype();
    Ok(Self {
        wtes,
        wpe,
        h,
        ln_f,
        lm_heads,
        cfg,
        dtype,
    })
}
/// Returns the configuration this model was built with.
pub fn config(&self) -> &Config {
&self.cfg
}
/// Runs the model on the rank-3 token tensor `idx`, whose middle dimension selects the
/// hierarchy (one embedding table per hierarchy).
///
/// Returns one logits tensor per output head.
pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
    let device = idx.device();
    let (b, _num_hierarchies, t) = idx.dims3()?;
    let positions = Tensor::arange(0u32, t as u32, device)?;
    let pos_emb = positions.apply(&self.wpe)?;
    // Sum the embeddings of all the hierarchies.
    let zeros = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
    let tok_emb = self
        .wtes
        .iter()
        .enumerate()
        .try_fold(zeros, |acc, (wte_idx, wte)| -> Result<Tensor> {
            let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
            acc + emb
        })?;
    // TODO: speaker embs.
    let spk_emb = 0f64;
    let mut hidden = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
    for block in &self.h {
        hidden = hidden.apply(block)?;
    }
    let hidden = hidden.apply(&self.ln_f)?;
    // non-causal mode only.
    self.lm_heads
        .iter()
        .map(|lm_head| hidden.apply(lm_head))
        .collect()
}
}
}
pub mod transformer {
use super::*;
/// Configuration of the transformer model.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    pub block_size: usize,
    pub vocab_size: usize,
    pub n_layer: usize,
    pub n_head: usize,
    pub dim: usize,
    pub speaker_emb_dim: usize,
    /// Size of the feed-forward hidden layer; derived from `dim` when unset.
    pub intermediate_size: Option<usize>,
    /// Number of key/value heads; defaults to `n_head` when unset.
    pub n_local_heads: Option<usize>,
    pub norm_eps: f64,
}

impl Config {
    /// Returns the hard-coded set of values defined below.
    pub fn cfg1b_v0_1() -> Self {
        Self {
            block_size: 2048,
            vocab_size: 2562,
            n_layer: 24,
            n_head: 16,
            dim: 2048,
            speaker_emb_dim: 256,
            intermediate_size: None,
            n_local_heads: None,
            norm_eps: 1e-5,
        }
    }

    /// Number of key/value heads, falling back to the number of query heads.
    fn n_local_heads(&self) -> usize {
        match self.n_local_heads {
            Some(n_local_heads) => n_local_heads,
            None => self.n_head,
        }
    }

    /// Size of a single attention head.
    fn head_dim(&self) -> usize {
        self.dim / self.n_head
    }

    /// Feed-forward hidden size: the configured value if any, otherwise two thirds of
    /// `4 * dim` rounded up to the next multiple of 256.
    fn intermediate_size(&self) -> usize {
        if let Some(intermediate_size) = self.intermediate_size {
            return intermediate_size;
        }
        let hidden_dim = self.dim * 4;
        let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
        (n_hidden + 255) / 256 * 256
    }
}
/// Gated feed-forward network: `w2(silu(w1(x)) * w3(x))`.
#[derive(Debug, Clone)]
struct FeedForward {
    w1: Linear,
    w2: Linear,
    w3: Linear,
}

impl FeedForward {
    /// Loads the three bias-free projections from `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.intermediate_size();
        // Field initializers run in the order written: w1, then w2, then w3.
        Ok(Self {
            w1: linear_b(cfg.dim, hidden, false, vb.pp("swiglu.w1"))?,
            w2: linear_b(hidden, cfg.dim, false, vb.pp("w2"))?,
            w3: linear_b(cfg.dim, hidden, false, vb.pp("swiglu.w3"))?,
        })
    }
}

impl Module for FeedForward {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate = candle_nn::ops::silu(&xs.apply(&self.w1)?)?;
        let up = xs.apply(&self.w3)?;
        let gated = (gate * up)?;
        gated.apply(&self.w2)
    }
}
/// Self-attention with a fused query/key/value projection, possibly fewer key/value heads
/// than query heads, and an internal key/value cache.
#[derive(Debug, Clone)]
struct Attention {
    wqkv: Linear,
    wo: Linear,
    dim: usize,
    kv_size: usize,
    n_local_heads: usize,
    head_dim: usize,
    n_head: usize,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    /// Loads the projections from `vb`; the cache starts empty.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let n_local_heads = cfg.n_local_heads();
        let head_dim = cfg.head_dim();
        // Queries use `n_head` heads, keys and values `n_local_heads` heads each.
        let qkv_out = (cfg.n_head + 2 * n_local_heads) * head_dim;
        let wqkv = linear_b(cfg.dim, qkv_out, false, vb.pp("wqkv"))?;
        let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
        Ok(Self {
            wqkv,
            wo,
            dim: cfg.dim,
            kv_size: n_local_heads * head_dim,
            n_local_heads,
            head_dim,
            n_head: cfg.n_head,
            kv_cache: None,
        })
    }

    /// Attends over the cached keys/values followed by the ones computed from `xs`, and
    /// stores the concatenation back into the cache. `mask` is added to the attention
    /// scores before the softmax.
    fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
        let (b_sz, seqlen, _) = xs.dims3()?;
        let qkv = xs.apply(&self.wqkv)?;
        let q = qkv.narrow(D::Minus1, 0, self.dim)?;
        let new_k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
        let new_v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;

        let q_shape = (b_sz, seqlen, self.n_head, self.head_dim);
        let kv_shape = (b_sz, seqlen, self.n_local_heads, self.head_dim);
        let q = q.reshape(q_shape)?.transpose(1, 2)?.contiguous()?;
        let new_k = new_k.reshape(kv_shape)?.transpose(1, 2)?;
        let new_v = new_v.reshape(kv_shape)?.transpose(1, 2)?;

        // Extend the cached keys and values along the sequence dimension.
        let (k, v) = if let Some((prev_k, prev_v)) = &self.kv_cache {
            let k = Tensor::cat(&[prev_k, &new_k], 2)?;
            let v = Tensor::cat(&[prev_v, &new_v], 2)?;
            (k, v)
        } else {
            (new_k, new_v)
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        // Duplicate the key/value heads so that they match the number of query heads.
        let n_rep = self.n_head / self.n_local_heads;
        let k = repeat_interleave(&k, n_rep, 1)?;
        let v = repeat_interleave(&v, n_rep, 1)?;

        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        let scores = scores.broadcast_add(mask)?;
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        probs
            .matmul(&v)?
            .transpose(1, 2)?
            .reshape((b_sz, seqlen, self.dim))?
            .apply(&self.wo)
    }

    /// Drops the cached keys and values.
    fn clear_kv_cache(&mut self) {
        self.kv_cache = None;
    }
}
/// Transformer block: normalized attention and normalized feed-forward, each added to the
/// residual stream.
#[derive(Debug, Clone)]
struct Block {
    attention: Attention,
    feed_forward: FeedForward,
    ffn_norm: RmsNorm,
    attention_norm: RmsNorm,
}

impl Block {
    /// Loads the sub-layers from `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // Field initializers run in the order written, so layers load in this order.
        Ok(Self {
            attention: Attention::new(cfg, vb.pp("attention"))?,
            feed_forward: FeedForward::new(cfg, vb.pp("feed_forward"))?,
            ffn_norm: rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?,
            attention_norm: rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?,
        })
    }

    /// Runs the block on `xs`; `pos` and `mask` are forwarded to the attention layer.
    fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
        let normed = xs.apply(&self.attention_norm)?;
        let attn_out = self.attention.forward(&normed, pos, mask)?;
        let residual = (xs + &attn_out)?;
        let ffn_out = residual.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
        &residual + &ffn_out
    }

    /// Clears the key/value cache of the attention layer.
    fn clear_kv_cache(&mut self) {
        self.attention.clear_kv_cache();
    }
}
/// Transformer conditioned on a speaker embedding.
#[derive(Debug, Clone)]
pub struct Model {
tok_embeddings: Embedding,
pos_embeddings: Embedding,
// Projects the speaker embedding from `speaker_emb_dim` to `dim`.
speaker_cond_pos: Linear,
layers: Vec<Block>,
norm: RmsNorm,
output: Linear,
// Ones followed by zeros along the first dimension; multiplied with the projected speaker
// embedding in `forward`.
spk_cond_mask: Tensor,
}
impl Model {
/// Loads the model described by `cfg` from `vb`.
///
/// Returns an error if any weight fails to load.
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
    let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
    let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
    let speaker_cond_pos = linear_b(
        cfg.speaker_emb_dim,
        cfg.dim,
        false,
        vb.pp("speaker_cond_pos"),
    )?;

    let vb_l = vb.pp("layers");
    let layers = (0..cfg.n_layer)
        .map(|layer_idx| Block::new(cfg, vb_l.pp(layer_idx)))
        .collect::<Result<Vec<_>>>()?;

    let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
    let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;

    // Keeps the speaker conditioning for the first batch entry and zeroes it for the second.
    let (dtype, device) = (vb.dtype(), vb.device());
    let keep = Tensor::ones((1, 1, cfg.dim), dtype, device)?;
    let drop = Tensor::zeros((1, 1, cfg.dim), dtype, device)?;
    let spk_cond_mask = Tensor::cat(&[keep, drop], 0)?;

    Ok(Self {
        tok_embeddings,
        pos_embeddings,
        speaker_cond_pos,
        layers,
        norm,
        output,
        spk_cond_mask,
    })
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
/// Runs the model on the rank-2 token tensor `xs`, whose tokens start at position `pos`,
/// conditioned on the speaker embedding `spk_emb`.
///
/// Returns the output projection of the last position only.
pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
    let (_b_sz, seqlen) = xs.dims2()?;
    let device = xs.device();

    // Additive causal mask: a row cannot attend to the columns that come after it.
    let mut mask_values: Vec<f32> = Vec::with_capacity(seqlen * seqlen);
    for row in 0..seqlen {
        for col in 0..seqlen {
            mask_values.push(if row < col { f32::NEG_INFINITY } else { 0. });
        }
    }
    let mask = Tensor::from_slice(&mask_values, (1, 1, seqlen, seqlen), device)?;

    let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, device)?;
    let tok_embeddings = xs.apply(&self.tok_embeddings)?;
    let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
    let hidden = tok_embeddings.broadcast_add(&pos_embeddings)?;
    let spk_cond = spk_emb
        .apply(&self.speaker_cond_pos)?
        .broadcast_mul(&self.spk_cond_mask)?;
    let mut hidden = hidden.broadcast_add(&spk_cond)?;

    let mask = mask.to_dtype(hidden.dtype())?;
    for layer in self.layers.iter_mut() {
        hidden = layer.forward(&hidden, pos, &mask)?;
    }
    hidden
        .narrow(1, seqlen - 1, 1)?
        .apply(&self.norm)?
        .apply(&self.output)
}
}
}
pub mod adapters {
// https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
/// Splits per-codebook token streams into text ids and audio ids.
pub struct TiltedEncodec {
    end_of_audio_token: u32,
}

impl TiltedEncodec {
    pub fn new(end_of_audio_token: u32) -> Self {
        Self { end_of_audio_token }
    }

    /// Decodes one token stream per codebook.
    ///
    /// Tokens below the end-of-audio token are audio ids, and tokens equal to it are
    /// dropped. Tokens above it are text ids, collected from the first stream only. All the
    /// audio streams are truncated to the length of the shortest one.
    ///
    /// Returns `(text_ids, audio_ids)` with one audio vector per input stream.
    pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
        let eoa = self.end_of_audio_token;

        let text_ids = tokens
            .first()
            .map(|book| book.iter().copied().filter(|&t| t > eoa).collect::<Vec<u32>>())
            .unwrap_or_default();

        let mut extracted_audio_ids: Vec<Vec<u32>> = tokens
            .iter()
            .map(|book| book.iter().copied().filter(|&t| t < eoa).collect())
            .collect();

        if let Some(min_len) = extracted_audio_ids.iter().map(Vec::len).min() {
            for audio_ids in extracted_audio_ids.iter_mut() {
                audio_ids.truncate(min_len);
            }
        }
        (text_ids, extracted_audio_ids)
    }
}
// https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
/// Splits a flat token stream into text ids and the ids of two audio codebooks.
pub struct FlattenedInterleavedEncodec2Codebook {
    end_of_audio_token: u32,
}

impl FlattenedInterleavedEncodec2Codebook {
    pub fn new(end_of_audio_token: u32) -> Self {
        Self { end_of_audio_token }
    }

    /// Decodes a flat token stream.
    ///
    /// Tokens below the end-of-audio token belong to the first codebook. Tokens below twice
    /// that value belong to the second codebook and are shifted down by the end-of-audio
    /// token. Everything else is a text id.
    ///
    /// Returns `(text_ids, audio_ids1, audio_ids2)`.
    pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
        let eoa = self.end_of_audio_token;
        let (mut text_ids, mut audio_ids1, mut audio_ids2) = (vec![], vec![], vec![]);
        for &token in tokens {
            // The guards are tried in order, so the second bound is only computed for tokens
            // that are not in the first codebook.
            match token {
                t if t < eoa => audio_ids1.push(t),
                t if t < 2 * eoa => audio_ids2.push(t - eoa),
                t => text_ids.push(t),
            }
        }
        (text_ids, audio_ids1, audio_ids2)
    }
}
}

View File

@ -8,17 +8,13 @@ pub mod convnext;
pub mod dinov2; pub mod dinov2;
pub mod distilbert; pub mod distilbert;
pub mod efficientnet; pub mod efficientnet;
pub mod efficientvit;
pub mod encodec;
pub mod falcon; pub mod falcon;
pub mod gemma;
pub mod jina_bert; pub mod jina_bert;
pub mod llama; pub mod llama;
pub mod llama2_c; pub mod llama2_c;
pub mod llama2_c_weights; pub mod llama2_c_weights;
pub mod mamba; pub mod mamba;
pub mod marian; pub mod marian;
pub mod metavoice;
pub mod mistral; pub mod mistral;
pub mod mixformer; pub mod mixformer;
pub mod mixtral; pub mod mixtral;
@ -33,24 +29,20 @@ pub mod quantized_llama2_c;
pub mod quantized_mistral; pub mod quantized_mistral;
pub mod quantized_mixformer; pub mod quantized_mixformer;
pub mod quantized_mpt; pub mod quantized_mpt;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
pub mod quantized_stable_lm; pub mod quantized_stable_lm;
pub mod quantized_t5; pub mod quantized_t5;
pub mod qwen2; pub mod qwen2;
pub mod repvgg; pub mod repvgg;
pub mod resnet; pub mod resnet;
pub mod rwkv_v5; pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything; pub mod segment_anything;
pub mod stable_diffusion; pub mod stable_diffusion;
pub mod stable_lm; pub mod stable_lm;
pub mod starcoder2;
pub mod t5; pub mod t5;
pub mod trocr; pub mod trocr;
pub mod vgg; pub mod vgg;
pub mod vit; pub mod vit;
pub mod vocos;
pub mod whisper; pub mod whisper;
pub mod with_tracing; pub mod with_tracing;
pub mod wuerstchen; pub mod wuerstchen;

View File

@ -157,16 +157,16 @@ struct LayerWeights {
head_dim: usize, head_dim: usize,
cos: Tensor, cos: Tensor,
sin: Tensor, sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>, kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span, span_attn: tracing::Span,
span_rot: tracing::Span, span_rot: tracing::Span,
span_mlp: tracing::Span, span_mlp: tracing::Span,
} }
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> { fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape(); let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m) Ok(m)
} }
@ -240,7 +240,7 @@ impl LayerWeights {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = mask.broadcast_as(att.shape())?; let mask = mask.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, &self.neg_inf)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax_last_dim(&att)?; let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?; let y = att.matmul(&v.contiguous()?)?;
@ -298,7 +298,6 @@ impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
@ -338,7 +337,6 @@ impl ModelWeights {
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
cos: cos.clone(), cos: cos.clone(),
sin: sin.clone(), sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None, kv_cache: None,
span_attn, span_attn,
span_rot, span_rot,
@ -387,7 +385,6 @@ impl ModelWeights {
.and_then(|m| m.to_f32()) .and_then(|m| m.to_f32())
.unwrap_or(10000f32); .unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?; let tok_embeddings = tok_embeddings.dequantize(device)?;
@ -458,7 +455,6 @@ impl ModelWeights {
head_dim: embedding_length / head_count, head_dim: embedding_length / head_count,
cos: cos.clone(), cos: cos.clone(),
sin: sin.clone(), sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None, kv_cache: None,
span_attn, span_attn,
span_rot, span_rot,

View File

@ -7,7 +7,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)? xs / (xs.neg()?.exp()? + 1.0)?
} }
#[derive(Debug, Clone)]
struct CausalSelfAttention { struct CausalSelfAttention {
q_proj: Linear, q_proj: Linear,
k_proj: Linear, k_proj: Linear,
@ -16,13 +15,14 @@ struct CausalSelfAttention {
n_head: usize, n_head: usize,
n_key_value_head: usize, n_key_value_head: usize,
head_dim: usize, head_dim: usize,
cache: Cache,
} }
impl CausalSelfAttention { impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> { fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?; let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = cache.cos.i(index_pos..index_pos + seq_len)?; let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = cache.sin.i(index_pos..index_pos + seq_len)?; let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?; let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?; let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?; let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -36,13 +36,7 @@ impl CausalSelfAttention {
Ok(rope) Ok(rope)
} }
fn forward( fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?; let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?; let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?; let k = self.k_proj.forward(x)?;
@ -52,15 +46,16 @@ impl CausalSelfAttention {
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?; let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos, cache)?; let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos, cache)?; let mut k = self.apply_rotary_emb(&k, index_pos)?;
if cache.use_kv_cache { if self.cache.use_kv_cache {
if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] { let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?; k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?; v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
} }
cache.kvs[block_idx] = Some((k.clone(), v.clone())) cache[block_idx] = Some((k.clone(), v.clone()))
} }
let k = self.repeat_kv(k)?; let k = self.repeat_kv(k)?;
@ -71,7 +66,7 @@ impl CausalSelfAttention {
let v = v.transpose(1, 2)?.contiguous()?; let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?; let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now. // Convert to contiguous as matmul doesn't support strided vs for now.
@ -95,7 +90,7 @@ impl CausalSelfAttention {
} }
} }
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim; let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads; let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads; let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -111,6 +106,7 @@ impl CausalSelfAttention {
n_head: cfg.n_heads, n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads, n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads, head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
}) })
} }
} }
@ -122,7 +118,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug, Clone)]
struct Mlp { struct Mlp {
c_fc1: Linear, c_fc1: Linear,
c_fc2: Linear, c_fc2: Linear,
@ -153,7 +148,6 @@ impl Mlp {
} }
} }
#[derive(Debug, Clone)]
struct Block { struct Block {
rms_1: RmsNorm, rms_1: RmsNorm,
attn: CausalSelfAttention, attn: CausalSelfAttention,
@ -171,23 +165,17 @@ impl Block {
} }
} }
fn forward( fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
&self,
x: &Tensor,
index_pos: usize,
block_idx: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let residual = x; let residual = x;
let x = self.rms_1.forward(x)?; let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?; let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x; let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?; let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x) Ok(x)
} }
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?; let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?; let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = let post_attention_layernorm =
@ -201,7 +189,6 @@ impl Block {
} }
} }
#[derive(Debug, Clone)]
pub struct QLlama { pub struct QLlama {
wte: Embedding, wte: Embedding,
blocks: Vec<Block>, blocks: Vec<Block>,
@ -211,23 +198,23 @@ pub struct QLlama {
} }
impl QLlama { impl QLlama {
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> { pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?; let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?; let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() { for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?; x = block.forward(&x, index_pos, block_idx)?;
} }
let x = self.ln_f.forward(&x)?; let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?; let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32) logits.to_dtype(DType::F32)
} }
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> { pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?; let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?; let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers) let blocks: Vec<_> = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap()) .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
.collect(); .collect();
Ok(Self { Ok(Self {
wte, wte,

View File

@ -1,286 +0,0 @@
use crate::{
quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
quantized_var_builder::VarBuilder,
};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{GroupNorm, LayerNorm, Module};
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
/// Time-mixing attention layer of the quantized RWKV v5 model.
///
/// `forward` reads and updates the `extract_key_value` and `linear_attention` entries of
/// `state.per_layer[layer_id]`.
#[derive(Debug, Clone)]
struct SelfAttention {
    // Projections from the hidden size to the attention hidden size.
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    // Projection from the attention hidden size back to the hidden size.
    output: Linear,
    // Normalization applied to the attention output before gating.
    ln_x: candle_nn::GroupNorm,
    // Mixing coefficients, loaded with shape (1, 1, hidden_size).
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    // Per-head parameters, loaded with shape (n_attn_heads, head_size).
    time_decay: Tensor,
    time_faaaa: Tensor,
    // Mixing coefficient for the gate, loaded with shape (1, 1, hidden_size).
    time_mix_gate: Tensor,
    // Index of this layer in `State::per_layer`.
    layer_id: usize,
    // hidden_size / head_size.
    n_attn_heads: usize,
}

impl SelfAttention {
    /// Loads the layer parameters from `vb`.
    ///
    /// The projections are built through the `quantized_nn` helpers; the group-norm weights and
    /// every `time_*` tensor are dequantized onto `vb.device()`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;

        // The group norm is assembled by hand from dequantized weight and bias tensors.
        let vb_x = vb.pp("ln_x");
        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
        let ln_x = GroupNorm::new(
            ln_x_weight,
            ln_x_bias,
            hidden_size,
            hidden_size / cfg.head_size,
            1e-5,
        )?;

        let time_mix_key = vb
            .get((1, 1, cfg.hidden_size), "time_mix_key")?
            .dequantize(vb.device())?;
        let time_mix_value = vb
            .get((1, 1, cfg.hidden_size), "time_mix_value")?
            .dequantize(vb.device())?;
        let time_mix_receptance = vb
            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
            .dequantize(vb.device())?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb
            .get((n_attn_heads, cfg.head_size), "time_decay")?
            .dequantize(vb.device())?;
        let time_faaaa = vb
            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
            .dequantize(vb.device())?;
        let time_mix_gate = vb
            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
            .dequantize(vb.device())?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            layer_id,
            n_attn_heads,
        })
    }

    /// Runs the layer on a rank-3 input `xs` and updates this layer's entries in `state`.
    ///
    /// Returns the gated, normalized attention output after the `output` projection.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        // Number of heads, read from the first dimension of `time_decay`.
        let h = self.time_decay.dim(0)?;
        let (b, t, s) = xs.dims3()?;
        // From here on `s` is the per-head size.
        let s = s / h;
        let (receptance, key, value, gate) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            // A rank-2 stored value gets a length-1 axis inserted at position 1.
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };
            // Each path interpolates between the current input and the stored one:
            // xs * mix + shifted * (1 - mix).
            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
            let receptance = ((xs * &self.time_mix_receptance)?
                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;

            let key = self.key.forward(&key)?;
            let value = self.value.forward(&value)?;
            let receptance = self.receptance.forward(&receptance)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
            // Remember the last time step of this input for the next call.
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate)
        };
        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        // key: (b, h, s, t); value and receptance: (b, h, t, s).
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        // Decay factor exp(-exp(time_decay)), reshaped to (n_attn_heads, head_size, 1).
        let time_decay = self
            .time_decay
            .exp()?
            .neg()?
            .exp()?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;
        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        // One recurrent update per time step.
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            // Product of this step's key column and value row.
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            // Decay the running state and add this step's contribution.
            state_ = (&at + time_decay.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        // Persist the updated running state for the next call.
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
/// Feed-forward layer of the quantized RWKV v5 block.
///
/// `forward` reads and then overwrites the `feed_forward` entry of `state.per_layer[layer_id]`.
#[derive(Debug, Clone)]
struct FeedForward {
    /// Mixing coefficient for the key path, loaded with shape (1, 1, hidden_size).
    time_mix_key: Tensor,
    /// Mixing coefficient for the receptance path, loaded with shape (1, 1, hidden_size).
    time_mix_receptance: Tensor,
    /// Projection from the hidden size to the intermediate size.
    key: Linear,
    /// Projection from the hidden size to the hidden size.
    receptance: Linear,
    /// Projection from the intermediate size back to the hidden size.
    value: Linear,
    /// Index of this layer in `State::per_layer`.
    layer_id: usize,
}

impl FeedForward {
    /// Loads the layer parameters from `vb`; the mixing coefficients are dequantized onto
    /// `vb.device()`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        // Without an explicit size, use 3.5x the hidden size rounded down to a multiple of 32.
        let inner = match cfg.intermediate_size {
            Some(size) => size,
            None => ((hidden as f64 * 3.5) as usize) / 32 * 32,
        };

        let key = linear(hidden, inner, vb.pp("key"))?;
        let receptance = linear(hidden, hidden, vb.pp("receptance"))?;
        let value = linear(inner, hidden, vb.pp("value"))?;

        let mix_shape = (1, 1, hidden);
        let time_mix_key = vb.get(mix_shape, "time_mix_key")?;
        let time_mix_key = time_mix_key.dequantize(vb.device())?;
        let time_mix_receptance = vb.get(mix_shape, "time_mix_receptance")?;
        let time_mix_receptance = time_mix_receptance.dequantize(vb.device())?;

        Ok(Self {
            key,
            receptance,
            value,
            time_mix_key,
            time_mix_receptance,
            layer_id,
        })
    }

    /// Applies the layer to `xs` and stores the last time step of `xs` in `state`.
    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let previous = &state.per_layer[self.layer_id].feed_forward;

        // Interpolate between the current input and the stored one: xs * mix + previous * (1 - mix).
        let key_weight = xs.broadcast_mul(&self.time_mix_key)?;
        let key_rest = previous.broadcast_mul(&(1.0 - &self.time_mix_key)?)?;
        let key_in = (key_weight + key_rest)?;
        let receptance_weight = xs.broadcast_mul(&self.time_mix_receptance)?;
        let receptance_rest = previous.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?;
        let receptance_in = (receptance_weight + receptance_rest)?;

        // Squared ReLU on the key path, then the value projection.
        let activated = key_in.apply(&self.key)?.relu()?.sqr()?;
        let projected = activated.apply(&self.value)?;
        let gate = candle_nn::ops::sigmoid(&receptance_in.apply(&self.receptance)?)?;

        // Remember the last time step of this input for the next call.
        let last_step = xs.dim(1)? - 1;
        state.per_layer[self.layer_id].feed_forward = xs.i((.., last_step))?;

        gate * projected
    }
}
/// One quantized RWKV v5 layer: optional pre-norm, attention and feed-forward, each with a
/// residual connection.
#[derive(Debug, Clone)]
struct Block {
    /// Extra layer norm, only loaded when `layer_id == 0`.
    pre_ln: Option<LayerNorm>,
    /// Layer norm applied to the attention input.
    ln1: LayerNorm,
    /// Layer norm applied to the feed-forward input.
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    /// Loads the norms and both sub-layers for layer `layer_id` from `vb`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (size, eps) = (cfg.hidden_size, cfg.layer_norm_epsilon);
        let ln1 = layer_norm(size, eps, vb.pp("ln1"))?;
        let ln2 = layer_norm(size, eps, vb.pp("ln2"))?;
        // Only the first layer carries the additional pre-norm.
        let pre_ln = (layer_id == 0)
            .then(|| layer_norm(size, eps, vb.pp("pre_ln")))
            .transpose()?;
        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention,
            feed_forward,
        })
    }

    /// Runs the block on `xs`, updating the recurrent `state` through both sub-layers.
    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let input = if let Some(norm) = &self.pre_ln {
            xs.apply(norm)?
        } else {
            xs.clone()
        };
        // Attention sub-layer with residual connection.
        let attn_out = self.attention.forward(&input.apply(&self.ln1)?, state)?;
        let hidden = (input + attn_out)?;
        // Feed-forward sub-layer with residual connection.
        let ff_out = self.feed_forward.forward(&hidden.apply(&self.ln2)?, state)?;
        hidden + ff_out
    }
}
/// Quantized RWKV v5 model: token embeddings, a stack of blocks, a final layer norm and the
/// output head.
#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    /// Period, in blocks, of the optional halving applied in `forward`.
    rescale_every: usize,
    /// Whether `forward` halves the activations every `rescale_every` blocks.
    layers_are_rescaled: bool,
}

impl Model {
    /// Loads the model from `vb`; everything except the head lives under the `rwkv` prefix.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let vb_b = vb_m.pp("blocks");
        // Stops at the first block that fails to load.
        let blocks = (0..cfg.num_hidden_layers)
            .map(|block_index| Block::new(block_index, cfg, vb_b.pp(block_index)))
            .collect::<Result<Vec<_>>>()?;
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            // This seems to only happen for the f16/bf16 dtypes.
            layers_are_rescaled: false,
        })
    }

    /// Runs the model on a rank-2 tensor of token ids and increments `state.pos` by one.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        // Only checks that the input is rank 2; the dimensions themselves are unused.
        xs.dims2()?;
        let mut hidden = xs.apply(&self.embeddings)?;
        for (idx, block) in self.blocks.iter().enumerate() {
            hidden = block.forward(&hidden, state)?;
            let rescale_here = self.layers_are_rescaled && (idx + 1) % self.rescale_every == 0;
            if rescale_here {
                hidden = (hidden / 2.)?;
            }
        }
        let logits = hidden.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(logits)
    }
}

View File

@ -1,332 +0,0 @@
use crate::{
quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
quantized_var_builder::VarBuilder,
};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{GroupNorm, LayerNorm, Module};
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
/// Time-mixing attention layer of the quantized RWKV v6 model.
///
/// `forward` reads and updates the `extract_key_value` and `linear_attention` entries of
/// `state.per_layer[layer_id]`.
#[derive(Debug, Clone)]
struct SelfAttention {
    // Projections from the hidden size to the attention hidden size.
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    // Projection from the attention hidden size back to the hidden size.
    output: Linear,
    // Normalization applied to the attention output before gating.
    ln_x: candle_nn::GroupNorm,
    // Mixing coefficients, loaded with shape (1, 1, hidden_size).
    time_mix_x: Tensor,
    time_mix_w: Tensor,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    // Base decay, loaded with shape (1, 1, hidden_size).
    time_decay: Tensor,
    // Loaded with shape (n_attn_heads, head_size).
    time_faaaa: Tensor,
    time_mix_gate: Tensor,
    // Factors of the input-dependent decay term:
    // (hidden_size, n_attn_heads * 2) and (n_attn_heads * 2, hidden_size).
    time_decay_w1: Tensor,
    time_decay_w2: Tensor,
    // Factors of the input-dependent mixing terms:
    // (hidden_size, n_attn_heads * 5) and (5, n_attn_heads, hidden_size).
    time_mix_w1: Tensor,
    time_mix_w2: Tensor,
    // Index of this layer in `State::per_layer`.
    layer_id: usize,
    // hidden_size / head_size.
    n_attn_heads: usize,
}

impl SelfAttention {
    /// Loads the layer parameters from `vb`.
    ///
    /// The projections are built through the `quantized_nn` helpers; the group-norm weights and
    /// every `time_*` tensor are dequantized onto `vb.device()`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;

        // The group norm is assembled by hand from dequantized weight and bias tensors.
        let vb_x = vb.pp("ln_x");
        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;
        let ln_x = GroupNorm::new(
            ln_x_weight,
            ln_x_bias,
            hidden_size,
            hidden_size / cfg.head_size,
            1e-5,
        )?;

        let time_mix_x = vb
            .get((1, 1, cfg.hidden_size), "time_mix_x")?
            .dequantize(vb.device())?;
        let time_mix_w = vb
            .get((1, 1, cfg.hidden_size), "time_mix_w")?
            .dequantize(vb.device())?;
        let time_mix_key = vb
            .get((1, 1, cfg.hidden_size), "time_mix_key")?
            .dequantize(vb.device())?;
        let time_mix_value = vb
            .get((1, 1, cfg.hidden_size), "time_mix_value")?
            .dequantize(vb.device())?;
        let time_mix_receptance = vb
            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
            .dequantize(vb.device())?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb
            .get((1, 1, cfg.hidden_size), "time_decay")?
            .dequantize(vb.device())?;
        let time_faaaa = vb
            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
            .dequantize(vb.device())?;
        let time_mix_gate = vb
            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
            .dequantize(vb.device())?;
        let time_decay_w1 = vb
            .get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
            .dequantize(vb.device())?;
        let time_decay_w2 = vb
            .get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
            .dequantize(vb.device())?;
        let time_mix_w1 = vb
            .get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
            .dequantize(vb.device())?;
        let time_mix_w2 = vb
            .get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
            .dequantize(vb.device())?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_x,
            time_mix_w,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            time_decay_w1,
            time_decay_w2,
            time_mix_w1,
            time_mix_w2,
            layer_id,
            n_attn_heads,
        })
    }

    /// Runs the layer on a rank-3 input `xs` and updates this layer's entries in `state`.
    ///
    /// Returns the gated, normalized attention output after the `output` projection.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.n_attn_heads;
        let (b, t, s) = xs.dims3()?;
        // From here on `s` is the per-head size.
        let s = s / h;
        let (receptance, key, value, gate, w) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            // A rank-2 stored value gets a length-1 axis inserted at position 1.
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };

            // Difference between the stored input and the current one.
            let sx = (&shifted - xs)?;

            // Input-dependent mixing offsets: project through `time_mix_w1`, apply tanh, split
            // the result into five groups and project each one through `time_mix_w2`.
            let xxx = (xs + &sx * &self.time_mix_x)?;
            let xxx = xxx
                .broadcast_matmul(&self.time_mix_w1)?
                .tanh()?
                .reshape((b * t, 5, ()))?
                .transpose(0, 1)?;

            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;

            // One offset per path: decay (w), key, value, receptance and gate.
            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);

            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;

            // Input-dependent decay: time_decay + tanh(xw @ w1) @ w2, reshaped to
            // (n_attn_heads, -1, 1).
            let w = (&self.time_decay
                + xw.broadcast_matmul(&self.time_decay_w1)?
                    .tanh()?
                    .broadcast_matmul(&self.time_decay_w2)?)?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;

            let key = self.key.forward(&xk)?;
            let value = self.value.forward(&xv)?;
            let receptance = self.receptance.forward(&xr)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
            // Remember the last time step of this input for the next call.
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate, w)
        };

        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        // key: (b, h, s, t); value and receptance: (b, h, t, s).
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        // Decay factor exp(-exp(w)).
        let w = w.exp()?.neg()?.exp()?;

        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        // One recurrent update per time step.
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            // Product of this step's key column and value row.
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            // Decay the running state and add this step's contribution.
            state_ = (&at + w.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        // Persist the updated running state for the next call.
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
/// Feed-forward layer of the quantized RWKV v6 block.
///
/// `forward` reads and then overwrites the `feed_forward` entry of `state.per_layer[layer_id]`.
#[derive(Debug, Clone)]
struct FeedForward {
    /// Mixing coefficient for the key path, loaded with shape (1, 1, hidden_size).
    time_mix_key: Tensor,
    /// Mixing coefficient for the receptance path, loaded with shape (1, 1, hidden_size).
    time_mix_receptance: Tensor,
    /// Projection from the hidden size to the intermediate size.
    key: Linear,
    /// Projection from the hidden size to the hidden size.
    receptance: Linear,
    /// Projection from the intermediate size back to the hidden size.
    value: Linear,
    /// Index of this layer in `State::per_layer`.
    layer_id: usize,
}

impl FeedForward {
    /// Loads the layer parameters from `vb`; the mixing coefficients are dequantized onto
    /// `vb.device()`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        // Without an explicit size, use 3.5x the hidden size rounded down to a multiple of 32.
        let inner = match cfg.intermediate_size {
            Some(size) => size,
            None => ((hidden as f64 * 3.5) as usize) / 32 * 32,
        };

        let key = linear(hidden, inner, vb.pp("key"))?;
        let receptance = linear(hidden, hidden, vb.pp("receptance"))?;
        let value = linear(inner, hidden, vb.pp("value"))?;

        let mix_shape = (1, 1, hidden);
        let time_mix_key = vb.get(mix_shape, "time_mix_key")?;
        let time_mix_key = time_mix_key.dequantize(vb.device())?;
        let time_mix_receptance = vb.get(mix_shape, "time_mix_receptance")?;
        let time_mix_receptance = time_mix_receptance.dequantize(vb.device())?;

        Ok(Self {
            key,
            receptance,
            value,
            time_mix_key,
            time_mix_receptance,
            layer_id,
        })
    }

    /// Applies the layer to `xs` and stores the last time step of `xs` in `state`.
    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        // Difference between the stored input and the current one.
        let delta = state.per_layer[self.layer_id]
            .feed_forward
            .broadcast_sub(xs)?;

        // Shift each path from `xs` towards the stored input by its mixing coefficient.
        let key_in = (xs + delta.broadcast_mul(&self.time_mix_key)?)?;
        let receptance_in = (xs + delta.broadcast_mul(&self.time_mix_receptance)?)?;

        // Squared ReLU on the key path, then the value projection.
        let activated = key_in.apply(&self.key)?.relu()?.sqr()?;
        let projected = activated.apply(&self.value)?;
        let gate = candle_nn::ops::sigmoid(&receptance_in.apply(&self.receptance)?)?;

        // Remember the last time step of this input for the next call.
        let last_step = xs.dim(1)? - 1;
        state.per_layer[self.layer_id].feed_forward = xs.i((.., last_step))?;

        gate * projected
    }
}
/// One quantized RWKV v6 layer: optional pre-norm, attention and feed-forward, each with a
/// residual connection.
#[derive(Debug, Clone)]
struct Block {
    /// Extra layer norm, only loaded when `layer_id == 0`.
    pre_ln: Option<LayerNorm>,
    /// Layer norm applied to the attention input.
    ln1: LayerNorm,
    /// Layer norm applied to the feed-forward input.
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    /// Loads the norms and both sub-layers for layer `layer_id` from `vb`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (size, eps) = (cfg.hidden_size, cfg.layer_norm_epsilon);
        let ln1 = layer_norm(size, eps, vb.pp("ln1"))?;
        let ln2 = layer_norm(size, eps, vb.pp("ln2"))?;
        // Only the first layer carries the additional pre-norm.
        let pre_ln = match layer_id {
            0 => Some(layer_norm(size, eps, vb.pp("pre_ln"))?),
            _ => None,
        };
        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention,
            feed_forward,
        })
    }

    /// Runs the block on `xs`, updating the recurrent `state` through both sub-layers.
    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let input = if let Some(norm) = &self.pre_ln {
            xs.apply(norm)?
        } else {
            xs.clone()
        };
        // Attention sub-layer with residual connection.
        let attn_out = self.attention.forward(&input.apply(&self.ln1)?, state)?;
        let hidden = (input + attn_out)?;
        // Feed-forward sub-layer with residual connection.
        let ff_out = self.feed_forward.forward(&hidden.apply(&self.ln2)?, state)?;
        hidden + ff_out
    }
}
/// Quantized RWKV v6 model: token embeddings, a stack of blocks, a final layer norm and the
/// output head.
#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    /// Period, in blocks, of the optional halving applied in `forward`.
    rescale_every: usize,
    /// Whether `forward` halves the activations every `rescale_every` blocks.
    layers_are_rescaled: bool,
}

impl Model {
    /// Loads the model from `vb`; everything except the head lives under the `rwkv` prefix.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let vb_b = vb_m.pp("blocks");
        // Stops at the first block that fails to load.
        let blocks = (0..cfg.num_hidden_layers)
            .map(|block_index| Block::new(block_index, cfg, vb_b.pp(block_index)))
            .collect::<Result<Vec<_>>>()?;
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            // This seems to only happen for the f16/bf16 dtypes.
            layers_are_rescaled: false,
        })
    }

    /// Runs the model on a rank-2 tensor of token ids and increments `state.pos` by one.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        // Only checks that the input is rank 2; the dimensions themselves are unused.
        xs.dims2()?;
        let mut hidden = xs.apply(&self.embeddings)?;
        for (idx, block) in self.blocks.iter().enumerate() {
            hidden = block.forward(&hidden, state)?;
            let rescale_here = self.layers_are_rescaled && (idx + 1) % self.rescale_every == 0;
            if rescale_here {
                hidden = (hidden / 2.)?;
            }
        }
        let logits = hidden.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(logits)
    }
}

View File

@ -186,14 +186,10 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm = layer_norm( let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
cfg.hidden_size,
cfg.layer_norm_eps,
vb.pp("input_layernorm"),
)?;
let post_attention_layernorm = layer_norm( let post_attention_layernorm = layer_norm(
cfg.hidden_size, cfg.hidden_size,
cfg.layer_norm_eps, cfg.norm_eps,
vb.pp("post_attention_layernorm"), vb.pp("post_attention_layernorm"),
)?; )?;
Ok(Self { Ok(Self {
@ -244,7 +240,7 @@ impl Model {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer) layers.push(layer)
} }
let norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?; let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self { Ok(Self {
embed_tokens, embed_tokens,

View File

@ -22,15 +22,15 @@ pub struct Config {
pub rescale_every: usize, pub rescale_every: usize,
} }
pub struct StatePerLayer { struct StatePerLayer {
pub extract_key_value: Tensor, extract_key_value: Tensor,
pub linear_attention: Tensor, linear_attention: Tensor,
pub feed_forward: Tensor, feed_forward: Tensor,
} }
pub struct State { pub struct State {
pub per_layer: Vec<StatePerLayer>, per_layer: Vec<StatePerLayer>,
pub pos: usize, pos: usize,
} }
impl State { impl State {
@ -124,7 +124,7 @@ impl SelfAttention {
let (b, t, s) = xs.dims3()?; let (b, t, s) = xs.dims3()?;
let s = s / h; let s = s / h;
let (receptance, key, value, gate) = { let (receptance, key, value, gate) = {
// extract key-value // exctract key-value
let shifted = state.per_layer[self.layer_id].extract_key_value.clone(); let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
let shifted = if shifted.rank() == 2 { let shifted = if shifted.rank() == 2 {
shifted.unsqueeze(1)? shifted.unsqueeze(1)?
@ -164,9 +164,10 @@ impl SelfAttention {
let mut out: Vec<Tensor> = Vec::with_capacity(t); let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t { for t_ in 0..t {
let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?; //
let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?; let rt = receptance.i((.., .., t_..t_ + 1))?;
let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?; let kt = key.i((.., .., .., t_..t_ + 1))?;
let vt = value.i((.., .., t_..t_ + 1))?;
let at = kt.matmul(&vt)?; let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?; let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?; let out_ = rt.matmul(&rhs)?.squeeze(2)?;

View File

@ -1,295 +0,0 @@
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
pub use crate::models::rwkv_v5::{Config, State, Tokenizer};
/// Time-mixing attention layer of the RWKV v6 model.
///
/// `forward` reads and updates the `extract_key_value` and `linear_attention` entries of
/// `state.per_layer[layer_id]`.
#[derive(Debug, Clone)]
struct SelfAttention {
    // Projections from the hidden size to the attention hidden size.
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    // Projection from the attention hidden size back to the hidden size.
    output: Linear,
    // Normalization applied to the attention output before gating.
    ln_x: candle_nn::GroupNorm,
    // Mixing coefficients, loaded with shape (1, 1, hidden_size).
    time_mix_x: Tensor,
    time_mix_w: Tensor,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    // Base decay, loaded with shape (1, 1, hidden_size).
    time_decay: Tensor,
    // Loaded with shape (n_attn_heads, head_size).
    time_faaaa: Tensor,
    time_mix_gate: Tensor,
    // Factors of the input-dependent decay term:
    // (hidden_size, n_attn_heads * 2) and (n_attn_heads * 2, hidden_size).
    time_decay_w1: Tensor,
    time_decay_w2: Tensor,
    // Factors of the input-dependent mixing terms:
    // (hidden_size, n_attn_heads * 5) and (5, n_attn_heads, hidden_size).
    time_mix_w1: Tensor,
    time_mix_w2: Tensor,
    // Index of this layer in `State::per_layer`.
    layer_id: usize,
    // hidden_size / head_size.
    n_attn_heads: usize,
}

impl SelfAttention {
    /// Loads the projections, the group norm and every `time_*` tensor from `vb`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
        let ln_x = candle_nn::group_norm(
            hidden_size / cfg.head_size,
            hidden_size,
            1e-5,
            vb.pp("ln_x"),
        )?;

        let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
        let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
        let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
        let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
        let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
        let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_x,
            time_mix_w,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            time_decay_w1,
            time_decay_w2,
            time_mix_w1,
            time_mix_w2,
            layer_id,
            n_attn_heads,
        })
    }

    /// Runs the layer on a rank-3 input `xs` and updates this layer's entries in `state`.
    ///
    /// Returns the gated, normalized attention output after the `output` projection.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.n_attn_heads;
        let (b, t, s) = xs.dims3()?;
        // From here on `s` is the per-head size.
        let s = s / h;
        let (receptance, key, value, gate, w) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            // A rank-2 stored value gets a length-1 axis inserted at position 1.
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };

            // Difference between the stored input and the current one.
            let sx = (&shifted - xs)?;

            // Input-dependent mixing offsets: project through `time_mix_w1`, apply tanh, split
            // the result into five groups and project each one through `time_mix_w2`.
            let xxx = (xs + &sx * &self.time_mix_x)?;
            let xxx = xxx
                .broadcast_matmul(&self.time_mix_w1)?
                .tanh()?
                .reshape((b * t, 5, ()))?
                .transpose(0, 1)?;

            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;

            // One offset per path: decay (w), key, value, receptance and gate.
            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);

            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;

            // Input-dependent decay: time_decay + tanh(xw @ w1) @ w2, reshaped to
            // (n_attn_heads, -1, 1).
            let w = (&self.time_decay
                + xw.broadcast_matmul(&self.time_decay_w1)?
                    .tanh()?
                    .broadcast_matmul(&self.time_decay_w2)?)?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;

            let key = self.key.forward(&xk)?;
            let value = self.value.forward(&xv)?;
            let receptance = self.receptance.forward(&xr)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
            // Remember the last time step of this input for the next call.
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate, w)
        };

        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        // key: (b, h, s, t); value and receptance: (b, h, t, s).
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        // Decay factor exp(-exp(w)).
        let w = w.exp()?.neg()?.exp()?;

        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        // One recurrent update per time step.
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            // Product of this step's key column and value row.
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            // Decay the running state and add this step's contribution.
            state_ = (&at + w.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        // Persist the updated running state for the next call.
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
/// Feed-forward part of a block: a squared-ReLU key/value projection gated by
/// a sigmoid receptance, mixed with the last token stored in the layer state.
#[derive(Debug, Clone)]
struct FeedForward {
    time_mix_key: Tensor,
    time_mix_receptance: Tensor,
    key: Linear,
    receptance: Linear,
    value: Linear,
    layer_id: usize,
}

impl FeedForward {
    /// Loads the projections and mixing vectors for layer `layer_id` from `vb`.
    ///
    /// When the config carries no explicit intermediate size, it falls back to
    /// `3.5 * hidden_size` rounded down to a multiple of 32.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let inner = match cfg.intermediate_size {
            Some(size) => size,
            None => ((hidden as f64 * 3.5) as usize) / 32 * 32,
        };
        // Field initialisers run top to bottom, which keeps the weight lookup order.
        Ok(Self {
            key: linear(hidden, inner, vb.pp("key"))?,
            receptance: linear(hidden, hidden, vb.pp("receptance"))?,
            value: linear(inner, hidden, vb.pp("value"))?,
            time_mix_key: vb.get((1, 1, hidden), "time_mix_key")?,
            time_mix_receptance: vb.get((1, 1, hidden), "time_mix_receptance")?,
            layer_id,
        })
    }

    /// Applies the feed-forward to `xs` and stores the last position of `xs`
    /// (index `dim(1) - 1`) in this layer's `feed_forward` state slot.
    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let layer_state = &mut state.per_layer[self.layer_id];
        // Difference between the stored previous token and the current input.
        let delta = layer_state.feed_forward.broadcast_sub(xs)?;
        let key_in = (xs + delta.broadcast_mul(&self.time_mix_key)?)?;
        let receptance_in = (xs + delta.broadcast_mul(&self.time_mix_receptance)?)?;
        let hidden = key_in.apply(&self.key)?.relu()?.sqr()?;
        let value = hidden.apply(&self.value)?;
        let receptance = candle_nn::ops::sigmoid(&receptance_in.apply(&self.receptance)?)?;
        let last = xs.dim(1)? - 1;
        layer_state.feed_forward = xs.i((.., last))?;
        receptance * value
    }
}
/// One residual block: attention then feed-forward, each behind its own layer
/// norm. Only the block with `layer_id == 0` carries the extra `pre_ln`.
#[derive(Debug, Clone)]
struct Block {
    pre_ln: Option<LayerNorm>,
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    /// Loads the weights of block `layer_id` from `vb`.
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (size, eps) = (cfg.hidden_size, cfg.layer_norm_epsilon);
        let ln1 = layer_norm(size, eps, vb.pp("ln1"))?;
        let ln2 = layer_norm(size, eps, vb.pp("ln2"))?;
        // The input normalisation only exists on the very first block.
        let pre_ln = (layer_id == 0)
            .then(|| layer_norm(size, eps, vb.pp("pre_ln")))
            .transpose()?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention: SelfAttention::new(layer_id, cfg, vb.pp("attention"))?,
            feed_forward: FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?,
        })
    }

    /// Runs the block on `xs`, updating the per-layer entries of `state`.
    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let input = if let Some(pre_ln) = &self.pre_ln {
            xs.apply(pre_ln)?
        } else {
            xs.clone()
        };
        let attn_out = self.attention.forward(&input.apply(&self.ln1)?, state)?;
        let residual = (input + attn_out)?;
        let ffn_out = self
            .feed_forward
            .forward(&residual.apply(&self.ln2)?, state)?;
        residual + ffn_out
    }
}
/// Full model: token embeddings, a stack of blocks, a final layer norm and the
/// output head.
#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    rescale_every: usize,
    layers_are_rescaled: bool,
}

impl Model {
    /// Loads the model weights from `vb`. Everything but the head lives under
    /// the `rwkv` prefix; the head is read from the top-level `head` prefix.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let vb_b = vb_m.pp("blocks");
        let blocks = (0..cfg.num_hidden_layers)
            .map(|idx| Block::new(idx, cfg, vb_b.pp(idx)))
            .collect::<Result<Vec<_>>>()?;
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            // Rescaling seems to only happen for the f16/bf16 dtypes.
            layers_are_rescaled: false,
        })
    }

    /// Runs the model on the 2D token-id tensor `xs`, returns the head output
    /// and increments `state.pos` by one.
    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        // Only checks that the input is two-dimensional.
        xs.dims2()?;
        let mut hidden = xs.apply(&self.embeddings)?;
        for (idx, block) in self.blocks.iter().enumerate() {
            hidden = block.forward(&hidden, state)?;
            let rescale = self.layers_are_rescaled && (idx + 1) % self.rescale_every == 0;
            if rescale {
                hidden = (hidden / 2.)?;
            }
        }
        let logits = hidden.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(logits)
    }
}

View File

@ -1,705 +0,0 @@
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
/// Configuration for the Segformer models in this file, deserialized with serde.
///
/// The `Vec` fields hold one entry per encoder block and are read at indices
/// `0..num_encoder_blocks` when the encoder is built.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
/// Label table; defaults to an empty map when the key is absent.
#[serde(default)]
pub id2label: HashMap<String, String>,
/// Number of input channels of the first patch embedding.
pub num_channels: usize,
/// Number of encoder blocks (stages).
pub num_encoder_blocks: usize,
/// Number of layers in each encoder block.
pub depths: Vec<usize>,
/// Sequence reduction ratio of each block; a value above 1 adds a strided
/// convolution and a layer norm in front of the key/value projections.
pub sr_ratios: Vec<usize>,
/// Hidden size (channel count) of each block.
pub hidden_sizes: Vec<usize>,
/// Kernel size of each block's patch-embedding convolution.
pub patch_sizes: Vec<usize>,
/// Stride of each block's patch-embedding convolution.
pub strides: Vec<usize>,
/// Number of attention heads in each block.
pub num_attention_heads: Vec<usize>,
/// Multiplier applied to the hidden size to get each block's MLP width.
pub mlp_ratios: Vec<usize>,
/// Activation used inside the MLP.
pub hidden_act: candle_nn::Activation,
/// Epsilon passed to the normalization layers.
pub layer_norm_eps: f64,
/// Channel count used inside the decode head.
pub decoder_hidden_size: usize,
}
/// Patch embedding: a strided convolution followed by a layer norm over the
/// channel dimension.
#[derive(Debug, Clone)]
struct SegformerOverlapPatchEmbeddings {
    projection: Conv2d,
    layer_norm: candle_nn::LayerNorm,
}

impl SegformerOverlapPatchEmbeddings {
    /// Loads the convolution from `proj` and the norm from `layer_norm`.
    ///
    /// The convolution uses a `patch_size` kernel, the given `stride`, and a
    /// padding of `patch_size / 2`.
    fn new(
        config: &Config,
        patch_size: usize,
        stride: usize,
        num_channels: usize,
        hidden_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let conv_cfg = Conv2dConfig {
            stride,
            padding: patch_size / 2,
            ..Default::default()
        };
        Ok(Self {
            projection: conv2d(num_channels, hidden_size, patch_size, conv_cfg, vb.pp("proj"))?,
            layer_norm: candle_nn::layer_norm(
                hidden_size,
                config.layer_norm_eps,
                vb.pp("layer_norm"),
            )?,
        })
    }
}

impl Module for SegformerOverlapPatchEmbeddings {
    /// Projects `x` with the convolution and normalizes the result, returning
    /// a tensor with the same shape as the convolution output.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let projected = self.projection.forward(x)?;
        // [B, C, H, W] -> [B, H * W, C]
        let tokens = projected.flatten_from(2)?.transpose(1, 2)?;
        let normed = self.layer_norm.forward(&tokens)?;
        // [B, H * W, C] -> [B, C, H, W]
        normed.transpose(1, 2)?.reshape(projected.shape())
    }
}
/// Multi-head self-attention whose keys and values can be computed on a
/// spatially reduced copy of the input (when `sr`/`layer_norm` are present).
#[derive(Debug, Clone)]
struct SegformerEfficientSelfAttention {
    num_attention_heads: usize,
    attention_head_size: usize,
    query: Linear,
    key: Linear,
    value: Linear,
    sr: Option<Conv2d>,
    layer_norm: Option<candle_nn::LayerNorm>,
}

impl SegformerEfficientSelfAttention {
    /// Loads the projections from `vb`; the reduction convolution and its layer
    /// norm are only loaded when `sequence_reduction_ratio > 1`.
    ///
    /// Fails when `hidden_size` is not a multiple of `num_attention_heads`.
    fn new(
        config: &Config,
        hidden_size: usize,
        num_attention_heads: usize,
        sequence_reduction_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        if hidden_size % num_attention_heads != 0 {
            candle::bail!(
                "The hidden size {} is not a multiple of the number of attention heads {}",
                hidden_size,
                num_attention_heads
            )
        }
        let attention_head_size = hidden_size / num_attention_heads;
        let all_head_size = num_attention_heads * attention_head_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;

        let reduce = sequence_reduction_ratio > 1;
        let sr = if reduce {
            let conv_cfg = Conv2dConfig {
                stride: sequence_reduction_ratio,
                ..Default::default()
            };
            Some(conv2d(
                hidden_size,
                hidden_size,
                sequence_reduction_ratio,
                conv_cfg,
                vb.pp("sr"),
            )?)
        } else {
            None
        };
        let layer_norm = if reduce {
            Some(candle_nn::layer_norm(
                hidden_size,
                config.layer_norm_eps,
                vb.pp("layer_norm"),
            )?)
        } else {
            None
        };
        Ok(Self {
            num_attention_heads,
            attention_head_size,
            query,
            key,
            value,
            sr,
            layer_norm,
        })
    }

    /// Splits the last dimension into heads:
    /// `[B, S, heads * head_size] -> [B, heads, S, head_size]`.
    fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
        let (batch, seq_length, _) = hidden_states.dims3()?;
        let per_head = (
            batch,
            seq_length,
            self.num_attention_heads,
            self.attention_head_size,
        );
        hidden_states.reshape(per_head)?.permute((0, 2, 1, 3))
    }
}

impl Module for SegformerEfficientSelfAttention {
    /// Takes a `[B, C, H, W]` tensor and returns the attention output with the
    /// heads merged back into the last dimension.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // [B, C, H, W] -> [B, H * W, C]
        let tokens = x.flatten_from(2)?.permute((0, 2, 1))?;
        let query = self
            .transpose_for_scores(self.query.forward(&tokens)?)?
            .contiguous()?;

        // Keys and values come from the reduced feature map when one is configured.
        let kv_source = match (&self.sr, &self.layer_norm) {
            (Some(sr), Some(layer_norm)) => {
                // [B, C, H, W] -> [B, H * W, C] on the reduced map
                let reduced = sr.forward(x)?.flatten_from(2)?.permute((0, 2, 1))?;
                layer_norm.forward(&reduced)?
            }
            // already [B, H * W, C]
            _ => tokens,
        };

        // standard self-attention
        let key = self
            .transpose_for_scores(self.key.forward(&kv_source)?)?
            .contiguous()?;
        let value = self
            .transpose_for_scores(self.value.forward(&kv_source)?)?
            .contiguous()?;
        let scale = f64::sqrt(self.attention_head_size as f64);
        let scores = (query.matmul(&key.t()?)? / scale)?;
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .contiguous()?
            .flatten_from(D::Minus2)
    }
}
/// Output projection applied after self-attention.
#[derive(Debug, Clone)]
struct SegformerSelfOutput {
    dense: Linear,
}

impl SegformerSelfOutput {
    /// Loads a `hidden_size -> hidden_size` linear layer from the `dense` prefix.
    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        linear(hidden_size, hidden_size, vb.pp("dense")).map(|dense| Self { dense })
    }
}

impl Module for SegformerSelfOutput {
    /// Applies the dense projection to `x`.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.dense)
    }
}
/// Self-attention followed by its output projection.
#[derive(Debug, Clone)]
struct SegformerAttention {
    attention: SegformerEfficientSelfAttention,
    output: SegformerSelfOutput,
}

impl SegformerAttention {
    /// Loads the attention from the `self` prefix and the projection from the
    /// `output` prefix.
    fn new(
        config: &Config,
        hidden_size: usize,
        num_attention_heads: usize,
        sequence_reduction_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        Ok(Self {
            attention: SegformerEfficientSelfAttention::new(
                config,
                hidden_size,
                num_attention_heads,
                sequence_reduction_ratio,
                vb.pp("self"),
            )?,
            output: SegformerSelfOutput::new(hidden_size, vb.pp("output"))?,
        })
    }
}

impl Module for SegformerAttention {
    /// Runs the attention on `x`, then projects the result.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let attended = self.attention.forward(x)?;
        self.output.forward(&attended)
    }
}
/// 3x3 convolution with `groups == dim`, i.e. one filter group per channel.
#[derive(Debug, Clone)]
struct SegformerDWConv {
    dw_conv: Conv2d,
}

impl SegformerDWConv {
    /// Loads a `dim -> dim` 3x3 convolution (stride 1, padding 1) from the
    /// `dwconv` prefix.
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let conv_cfg = Conv2dConfig {
            stride: 1,
            padding: 1,
            groups: dim,
            ..Default::default()
        };
        conv2d(dim, dim, 3, conv_cfg, vb.pp("dwconv")).map(|dw_conv| Self { dw_conv })
    }
}

impl Module for SegformerDWConv {
    /// Applies the convolution to `x`.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.dw_conv)
    }
}
/// Feed-forward network: linear, per-channel 3x3 convolution, activation,
/// linear.
#[derive(Debug, Clone)]
struct SegformerMixFFN {
    dense1: Linear,
    dw_conv: SegformerDWConv,
    act: Activation,
    dense2: Linear,
}

impl SegformerMixFFN {
    /// Loads the two linear layers and the convolution from `vb`; the
    /// activation is taken from `config.hidden_act`.
    fn new(
        config: &Config,
        in_features: usize,
        hidden_features: usize,
        out_features: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        Ok(Self {
            dense1: linear(in_features, hidden_features, vb.pp("dense1"))?,
            dw_conv: SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?,
            act: config.hidden_act,
            dense2: linear(hidden_features, out_features, vb.pp("dense2"))?,
        })
    }
}

impl Module for SegformerMixFFN {
    /// Takes and returns a `[B, C, H, W]` tensor; the linear layers run on the
    /// flattened `[B, H * W, C]` view and the convolution on the 2D map.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (batch, _, height, width) = x.dims4()?;
        // [B, C, H, W] -> [B, H * W, C]
        let tokens = x.flatten_from(2)?.permute((0, 2, 1))?;
        let expanded = self.dense1.forward(&tokens)?;
        let expanded_channels = expanded.dim(2)?;
        // Back to a 2D map so the convolution can mix neighbouring positions.
        let expanded_map = expanded
            .permute((0, 2, 1))?
            .reshape((batch, expanded_channels, height, width))?;
        let activated = self.dw_conv.forward(&expanded_map)?.apply(&self.act)?;
        let projected = self
            .dense2
            .forward(&activated.flatten_from(2)?.permute((0, 2, 1))?)?;
        let out_channels = projected.dim(2)?;
        projected
            .permute((0, 2, 1))?
            .reshape((batch, out_channels, height, width))
    }
}
/// One encoder layer: pre-norm attention and pre-norm feed-forward, each with
/// a residual connection.
#[derive(Debug, Clone)]
struct SegformerLayer {
    layer_norm_1: candle_nn::LayerNorm,
    attention: SegformerAttention,
    layer_norm_2: candle_nn::LayerNorm,
    mlp: SegformerMixFFN,
}

impl SegformerLayer {
    /// Loads the layer from `vb`; the feed-forward width is
    /// `hidden_size * mlp_ratio`.
    fn new(
        config: &Config,
        hidden_size: usize,
        num_attention_heads: usize,
        sequence_reduction_ratio: usize,
        mlp_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let eps = config.layer_norm_eps;
        Ok(Self {
            layer_norm_1: layer_norm(hidden_size, eps, vb.pp("layer_norm_1"))?,
            attention: SegformerAttention::new(
                config,
                hidden_size,
                num_attention_heads,
                sequence_reduction_ratio,
                vb.pp("attention"),
            )?,
            layer_norm_2: layer_norm(hidden_size, eps, vb.pp("layer_norm_2"))?,
            mlp: SegformerMixFFN::new(
                config,
                hidden_size,
                hidden_size * mlp_ratio,
                hidden_size,
                vb.pp("mlp"),
            )?,
        })
    }
}

impl Module for SegformerLayer {
    /// Takes and returns a `[B, C, H, W]` tensor.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let shape = x.dims4()?;
        // [B, H * W, C] -> [B, C, H, W]
        let to_map =
            |tokens: &Tensor| -> Result<Tensor> { tokens.permute((0, 2, 1))?.reshape(shape) };

        // [B, C, H, W] -> [B, H * W, C]
        let tokens = x.flatten_from(2)?.permute((0, 2, 1))?;
        // The attention wants [B, C, H, W] so it can run its conv2d, and it
        // hands back [B, H * W, C].
        let normed = to_map(&self.layer_norm_1.forward(&tokens)?)?;
        let after_attention = (self.attention.forward(&normed)? + tokens)?;

        let mlp_input = to_map(&self.layer_norm_2.forward(&after_attention)?)?;
        let mlp_output = self.mlp.forward(&mlp_input)?;
        to_map(&after_attention)? + mlp_output
    }
}
#[derive(Debug, Clone)]
struct SegformerEncoder {
/// config file
config: Config,
/// a list of embeddings
patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
/// a list of attention blocks, each consisting of layers
blocks: Vec<Vec<SegformerLayer>>,
/// a final list of layer norms
layer_norms: Vec<candle_nn::LayerNorm>,
}
impl SegformerEncoder {
fn new(config: Config, vb: VarBuilder) -> Result<Self> {
let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
for i in 0..config.num_encoder_blocks {
let patch_size = config.patch_sizes[i];
let stride = config.strides[i];
let hidden_size = config.hidden_sizes[i];
let num_channels = if i == 0 {
config.num_channels
} else {
config.hidden_sizes[i - 1]
};
patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
&config,
patch_size,
stride,
num_channels,
hidden_size,
vb.pp(&format!("patch_embeddings.{}", i)),
)?);
let mut layers = Vec::with_capacity(config.depths[i]);
for j in 0..config.depths[i] {
let sequence_reduction_ratio = config.sr_ratios[i];
let num_attention_heads = config.num_attention_heads[i];
let mlp_ratio = config.mlp_ratios[i];
layers.push(SegformerLayer::new(
&config,
hidden_size,
num_attention_heads,
sequence_reduction_ratio,
mlp_ratio,
vb.pp(&format!("block.{}.{}", i, j)),
)?);
}
blocks.push(layers);
layer_norms.push(layer_norm(
hidden_size,
config.layer_norm_eps,
vb.pp(&format!("layer_norm.{}", i)),
)?);
}
Ok(Self {
config,
patch_embeddings,
blocks,
layer_norms,
})
}
}
impl ModuleWithHiddenStates for SegformerEncoder {
    /// Runs every stage on `x` in turn and returns the normalized `[B, C, H, W]`
    /// output of each stage, first stage first.
    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
        let mut outputs = Vec::with_capacity(self.config.num_encoder_blocks);
        let mut current = x.clone();
        // The three lists are built together with one entry per stage.
        let stages = self
            .patch_embeddings
            .iter()
            .zip(&self.blocks)
            .zip(&self.layer_norms);
        for ((embed, layers), norm) in stages {
            current = embed.forward(&current)?;
            for layer in layers {
                current = layer.forward(&current)?;
            }
            let shape = current.dims4()?;
            // The layer norm works on [B, H * W, C]; restore [B, C, H, W] afterwards.
            let tokens = current.flatten_from(2)?.permute((0, 2, 1))?;
            current = norm.forward(&tokens)?.permute((0, 2, 1))?.reshape(shape)?;
            outputs.push(current.clone());
        }
        Ok(outputs)
    }
}
/// Thin wrapper that owns the encoder under the `encoder` weight prefix.
#[derive(Debug, Clone)]
struct SegformerModel {
    encoder: SegformerEncoder,
}

impl SegformerModel {
    /// Builds the encoder from a copy of `config`.
    fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
        SegformerEncoder::new(config.clone(), vb.pp("encoder")).map(|encoder| Self { encoder })
    }
}

impl ModuleWithHiddenStates for SegformerModel {
    /// Returns the encoder's per-stage hidden states for `x`.
    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
        let hidden_states = self.encoder.forward(x)?;
        Ok(hidden_states)
    }
}
/// Single linear projection to `decoder_hidden_size` channels.
#[derive(Debug, Clone)]
struct SegformerMLP {
    proj: Linear,
}

impl SegformerMLP {
    /// Loads an `input_dim -> config.decoder_hidden_size` linear layer from
    /// the `proj` prefix.
    fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
        linear(input_dim, config.decoder_hidden_size, vb.pp("proj")).map(|proj| Self { proj })
    }
}

impl Module for SegformerMLP {
    /// Applies the projection to `x`.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}
/// Segmentation head: projects every encoder stage to a common width,
/// upsamples all of them to the first stage's resolution, fuses them with a
/// 1x1 convolution and classifies each position.
#[derive(Debug, Clone)]
struct SegformerDecodeHead {
    linear_c: Vec<SegformerMLP>,
    linear_fuse: candle_nn::Conv2d,
    batch_norm: candle_nn::BatchNorm,
    classifier: candle_nn::Conv2d,
}

impl SegformerDecodeHead {
    /// Loads one projection per encoder block plus the fuse convolution, the
    /// batch norm and the classifier convolution.
    fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let num_blocks = config.num_encoder_blocks;
        let decoder_size = config.decoder_hidden_size;
        let linear_c = (0..num_blocks)
            .map(|i| {
                SegformerMLP::new(
                    config,
                    config.hidden_sizes[i],
                    vb.pp(&format!("linear_c.{}", i)),
                )
            })
            .collect::<Result<Vec<_>>>()?;
        let linear_fuse = conv2d_no_bias(
            decoder_size * num_blocks,
            decoder_size,
            1,
            Conv2dConfig::default(),
            vb.pp("linear_fuse"),
        )?;
        // NOTE(review): the batch norm reuses `layer_norm_eps` as its epsilon;
        // confirm this matches the checkpoint's reference implementation.
        let batch_norm =
            candle_nn::batch_norm(decoder_size, config.layer_norm_eps, vb.pp("batch_norm"))?;
        let classifier = conv2d_no_bias(
            decoder_size,
            num_labels,
            1,
            Conv2dConfig::default(),
            vb.pp("classifier"),
        )?;
        Ok(Self {
            linear_c,
            linear_fuse,
            batch_norm,
            classifier,
        })
    }

    /// Turns the per-stage `[B, C, H, W]` hidden states into per-position class
    /// scores at the resolution of the first hidden state.
    ///
    /// Fails when the number of hidden states differs from the number of
    /// projection layers.
    fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
        let (n_states, n_layers) = (encoder_hidden_states.len(), self.linear_c.len());
        if n_states != n_layers {
            candle::bail!(
                "The number of encoder hidden states {} is not equal to the number of linear layers {}",
                n_states,
                n_layers
            )
        }
        // The first hidden state has the finest resolution; everything is
        // upsampled to match it.
        let (_, _, target_h, target_w) = encoder_hidden_states[0].dims4()?;
        let mut projected = Vec::with_capacity(n_layers);
        for (state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
            let (batch, _, height, width) = state.dims4()?;
            let tokens = mlp.forward(&state.flatten_from(2)?.permute((0, 2, 1))?)?;
            let channels = tokens.dim(2)?;
            let upsampled = tokens
                .permute((0, 2, 1))?
                .reshape((batch, channels, height, width))?
                .upsample_nearest2d(target_h, target_w)?;
            projected.push(upsampled);
        }
        // Concatenate along channels starting from the last stage.
        projected.reverse();
        let fused = self.linear_fuse.forward(&Tensor::cat(&projected, 1)?)?;
        let normed = self.batch_norm.forward_t(&fused, false)?;
        self.classifier.forward(&normed.relu()?)
    }
}
/// Forward pass that returns a list of tensors rather than a single one.
///
/// Implemented in this file by `SegformerEncoder` (one tensor per encoder
/// block) and by `SegformerModel`, which delegates to its encoder.
trait ModuleWithHiddenStates {
/// Runs the module on `xs` and returns the collected hidden states.
fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
}
/// Segformer encoder with the segmentation decode head on top.
#[derive(Debug, Clone)]
pub struct SemanticSegmentationModel {
    segformer: SegformerModel,
    decode_head: SegformerDecodeHead,
}

impl SemanticSegmentationModel {
    /// Loads the encoder from the `segformer` prefix and the head from the
    /// `decode_head` prefix; `num_labels` is the number of output classes.
    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            segformer: SegformerModel::new(config, vb.pp("segformer"))?,
            decode_head: SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?,
        })
    }
}

impl Module for SemanticSegmentationModel {
    /// Encodes `x` and feeds all per-stage hidden states to the decode head.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.segformer
            .forward(x)
            .and_then(|states| self.decode_head.forward(&states))
    }
}
#[derive(Debug, Clone)]
pub struct ImageClassificationModel {
segformer: SegformerModel,
classifier: Linear,
}
impl ImageClassificationModel {
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
Ok(Self {
segformer,
classifier,
})
}
}
impl Module for ImageClassificationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let all_hidden_states = self.segformer.forward(x)?;
let hidden_states = all_hidden_states.last().unwrap();
let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
let mean = hidden_states.mean(1)?;
self.classifier.forward(&mean)
}
}
#[cfg(test)]
mod tests {
use super::*;
// Checks that a Segformer JSON config deserializes into `Config`. The fixture
// carries keys that `Config` does not declare (e.g. "architectures",
// "image_size"), so it also exercises that unknown keys are accepted.
#[test]
fn test_config_json_load() {
let raw_json = r#"{
"architectures": [
"SegformerForImageClassification"
],
"attention_probs_dropout_prob": 0.0,
"classifier_dropout_prob": 0.1,
"decoder_hidden_size": 256,
"depths": [
2,
2,
2,
2
],
"downsampling_rates": [
1,
4,
8,
16
],
"drop_path_rate": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_sizes": [
32,
64,
160,
256
],
"image_size": 224,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"mlp_ratios": [
4,
4,
4,
4
],
"model_type": "segformer",
"num_attention_heads": [
1,
2,
5,
8
],
"num_channels": 3,
"num_encoder_blocks": 4,
"patch_sizes": [
7,
3,
3,
3
],
"sr_ratios": [
8,
4,
2,
1
],
"strides": [
4,
2,
2,
2
],
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0"
}"#;
let config: Config = serde_json::from_str(raw_json).unwrap();
// Spot-check one vector field and one float field.
assert_eq!(vec![4, 2, 2, 2], config.strides);
assert_eq!(1e-6, config.layer_norm_eps);
}
}

View File

@ -4,7 +4,7 @@ use candle_nn::{Activation, LayerNorm, VarBuilder};
use serde::Deserialize; use serde::Deserialize;
use std::sync::Arc; use std::sync::Arc;
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm.py // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
#[derive(Debug, Clone, PartialEq, Deserialize)] #[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config { pub struct Config {
pub(crate) vocab_size: usize, pub(crate) vocab_size: usize,
@ -14,10 +14,10 @@ pub struct Config {
pub(crate) num_attention_heads: usize, pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize, pub(crate) num_key_value_heads: usize,
pub(crate) hidden_act: Activation, pub(crate) hidden_act: Activation,
pub(crate) partial_rotary_factor: f64, pub(crate) rope_pct: f64,
pub(crate) rope_theta: f64, pub(crate) rope_theta: f64,
pub(crate) max_position_embeddings: usize, pub(crate) max_position_embeddings: usize,
pub(crate) layer_norm_eps: f64, pub(crate) norm_eps: f64,
pub(crate) use_cache: bool, pub(crate) use_cache: bool,
#[serde(default)] #[serde(default)]
pub(crate) use_qkv_bias: bool, // Used in StableLM-2 pub(crate) use_qkv_bias: bool, // Used in StableLM-2
@ -35,10 +35,10 @@ impl Config {
num_attention_heads: 32, num_attention_heads: 32,
num_key_value_heads: 32, num_key_value_heads: 32,
hidden_act: Activation::Silu, hidden_act: Activation::Silu,
partial_rotary_factor: 0.25, rope_pct: 0.25,
rope_theta: 10_000., rope_theta: 10_000.,
max_position_embeddings: 4096, max_position_embeddings: 4096,
layer_norm_eps: 1e-5, norm_eps: 1e-5,
use_qkv_bias: false, use_qkv_bias: false,
use_cache: true, use_cache: true,
use_flash_attn, use_flash_attn,
@ -50,7 +50,7 @@ impl Config {
} }
pub fn rotary_ndims(&self) -> usize { pub fn rotary_ndims(&self) -> usize {
(self.head_dim() as f64 * self.partial_rotary_factor) as usize (self.head_dim() as f64 * self.rope_pct) as usize
} }
pub fn num_kv_groups(&self) -> usize { pub fn num_kv_groups(&self) -> usize {
@ -316,14 +316,11 @@ impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> { fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm = candle_nn::layer_norm( let input_layernorm =
cfg.hidden_size, candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
cfg.layer_norm_eps,
vb.pp("input_layernorm"),
)?;
let post_attention_layernorm = candle_nn::layer_norm( let post_attention_layernorm = candle_nn::layer_norm(
cfg.hidden_size, cfg.hidden_size,
cfg.layer_norm_eps, cfg.norm_eps,
vb.pp("post_attention_layernorm"), vb.pp("post_attention_layernorm"),
)?; )?;
Ok(Self { Ok(Self {
@ -375,7 +372,7 @@ impl Model {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer) layers.push(layer)
} }
let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?; let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self { Ok(Self {
embed_tokens, embed_tokens,

View File

@ -1,347 +0,0 @@
#![allow(unused)]
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder};
use std::sync::Arc;
/// Model configuration, deserialized with serde.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
/// Number of rows of the token embedding table.
vocab_size: usize,
/// Width of the hidden states.
hidden_size: usize,
/// Inner width of the MLP.
intermediate_size: usize,
/// Number of decoder layers.
num_hidden_layers: usize,
/// Number of query heads; the head dimension is `hidden_size / num_attention_heads`.
num_attention_heads: usize,
/// Number of key/value heads; each one is shared by
/// `num_attention_heads / num_key_value_heads` query heads.
num_key_value_heads: usize,
/// Activation used inside the MLP.
hidden_act: candle_nn::Activation,
/// Number of positions precomputed in the rotary sin/cos tables.
max_position_embeddings: usize,
/// Epsilon passed to every layer norm.
norm_epsilon: f64,
/// Base used to derive the rotary frequencies.
rope_theta: f64,
/// Whether the attention and MLP linear layers carry a bias.
use_bias: bool,
/// Optional attention window; when set, a query at position `i` cannot attend
/// to keys at positions `j` with `j + sliding_window < i`.
sliding_window: Option<usize>,
}
/// Precomputed rotary-embedding tables with one row per position.
#[derive(Debug, Clone)]
struct RotaryEmbedding {
/// Sine of the position/frequency products.
sin: Tensor,
/// Cosine of the position/frequency products.
cos: Tensor,
}
/// Splits the last dimension of `xs` into a first and a second half and
/// returns `[-second, first]` concatenated along that dimension.
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let width = xs.dim(D::Minus1)?;
    let half = width / 2;
    let front = xs.narrow(D::Minus1, 0, half)?;
    // The second half takes the extra element when the width is odd.
    let back = xs.narrow(D::Minus1, half, width - half)?;
    Tensor::cat(&[&back.neg()?, &front], D::Minus1)
}
impl RotaryEmbedding {
    /// Precomputes the sin/cos tables for `cfg.max_position_embeddings`
    /// positions and a head dimension of `hidden_size / num_attention_heads`.
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let positions = cfg.max_position_embeddings;
        // One inverse frequency per pair of channels.
        let inv_freq: Vec<f32> = (0..head_dim)
            .step_by(2)
            .map(|i| {
                let exponent = i as f64 / head_dim as f64;
                1f32 / (cfg.rope_theta.powf(exponent) as f32)
            })
            .collect();
        let n_freqs = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, n_freqs), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, positions as u32, dev)?
            .to_dtype(dtype)?
            .reshape((positions, 1))?;
        // (positions, 1) x (1, n_freqs), then repeated to cover both halves.
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        let sin = freqs.sin()?;
        let cos = freqs.cos()?;
        Ok(Self { sin, cos })
    }

    /// Rotates `q` and `k` using the table rows starting at `seqlen_offset`;
    /// the number of rows is the third dimension of `q`.
    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_, _, seq_len, _) = q.dims4()?;
        // (seq_len, dim) -> (1, 1, seq_len, dim)
        let cos = self
            .cos
            .narrow(0, seqlen_offset, seq_len)?
            .unsqueeze(0)?
            .unsqueeze(0)?;
        let sin = self
            .sin
            .narrow(0, seqlen_offset, seq_len)?
            .unsqueeze(0)?
            .unsqueeze(0)?;
        let rotate = |t: &Tensor| -> Result<Tensor> {
            t.broadcast_mul(&cos)? + rotate_half(t)?.broadcast_mul(&sin)?
        };
        let q_embed = rotate(q)?;
        let k_embed = rotate(k)?;
        Ok((q_embed, k_embed))
    }
}
/// Two-layer feed-forward network with an activation in between.
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Linear,
    c_proj: Linear,
    act: candle_nn::Activation,
}

impl MLP {
    /// Loads `c_fc` (hidden -> intermediate) and `c_proj` (intermediate ->
    /// hidden); both carry a bias only when `cfg.use_bias` is set.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let inner = cfg.intermediate_size;
        Ok(Self {
            c_fc: linear_b(hidden, inner, cfg.use_bias, vb.pp("c_fc"))?,
            c_proj: linear_b(inner, hidden, cfg.use_bias, vb.pp("c_proj"))?,
            act: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    /// Computes `c_proj(act(c_fc(xs)))`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let expanded = self.c_fc.forward(xs)?;
        let activated = self.act.forward(&expanded)?;
        self.c_proj.forward(&activated)
    }
}
/// Multi-head attention with rotary embeddings, grouped key/value heads and a
/// key/value cache that grows with every call to `forward`.
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    /// Loads the four projections from `vb`; the cache starts out empty.
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_size / num_heads;
        let (q_out, kv_out) = (num_heads * head_dim, num_kv_heads * head_dim);
        let bias = cfg.use_bias;
        Ok(Self {
            q_proj: linear_b(hidden_size, q_out, bias, vb.pp("q_proj"))?,
            k_proj: linear_b(hidden_size, kv_out, bias, vb.pp("k_proj"))?,
            v_proj: linear_b(hidden_size, kv_out, bias, vb.pp("v_proj"))?,
            o_proj: linear_b(q_out, hidden_size, bias, vb.pp("o_proj"))?,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size,
            rotary_emb,
            kv_cache: None,
        })
    }

    /// Repeats every key/value head `num_kv_groups` times so that the head
    /// count matches the query heads; a no-op when there is a single group.
    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        match self.num_kv_groups {
            1 => Ok(xs),
            n_rep => {
                let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
                xs.unsqueeze(2)?
                    .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                    .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
            }
        }
    }

    /// Attends over `xs` plus everything already in the cache.
    ///
    /// `seqlen_offset` selects the rotary table rows for the new positions and
    /// `attention_mask`, when given, is added to the scores before the softmax.
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;
        let q = self.q_proj.forward(xs)?;
        let k = self.k_proj.forward(xs)?;
        let v = self.v_proj.forward(xs)?;

        // (b, len, heads * dim) -> (b, heads, len, dim)
        let head_dim = self.head_dim;
        let split_heads = |t: Tensor, heads: usize| -> Result<Tensor> {
            t.reshape((b_sz, q_len, heads, head_dim))?.transpose(1, 2)
        };
        let q = split_heads(q, self.num_heads)?;
        let k = split_heads(k, self.num_kv_heads)?;
        let v = split_heads(v, self.num_kv_heads)?;

        let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k, seqlen_offset)?;

        // Append the new keys/values to whatever the cache already holds.
        let (k, v) = if let Some((prev_k, prev_v)) = &self.kv_cache {
            let k = Tensor::cat(&[prev_k, &k], 2)?;
            let v = Tensor::cat(&[prev_v, &v], 2)?;
            (k, v)
        } else {
            (k, v)
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;

        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        let scores = if let Some(mask) = attention_mask {
            scores.broadcast_add(mask)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        probs
            .matmul(&v)?
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }

    /// Drops the cached keys and values.
    fn clear_kv_cache(&mut self) {
        self.kv_cache = None;
    }
}
/// One decoder layer: pre-norm attention and pre-norm MLP, each with a
/// residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    /// Loads the layer from `vb`; `rotary_emb` is shared with the attention.
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (size, eps) = (cfg.hidden_size, cfg.norm_epsilon);
        Ok(Self {
            self_attn: Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?,
            mlp: MLP::new(cfg, vb.pp("mlp"))?,
            input_layernorm: layer_norm(size, eps, vb.pp("input_layernorm"))?,
            post_attention_layernorm: layer_norm(size, eps, vb.pp("post_attention_layernorm"))?,
        })
    }

    /// Runs the layer on `xs`; the mask and offset are passed on to the
    /// attention unchanged.
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let normed = self.input_layernorm.forward(xs)?;
        let attended = self
            .self_attn
            .forward(&normed, attention_mask, seqlen_offset)?;
        let hidden = (attended + xs)?;
        let mlp_out = hidden
            .apply(&self.post_attention_layernorm)?
            .apply(&self.mlp)?;
        &hidden + mlp_out
    }

    /// Empties the attention's key/value cache.
    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}
/// Decoder-only language model whose output head reuses the token embedding
/// matrix as its weight.
#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Linear,
    sliding_window: Option<usize>,
    device: Device,
    dtype: DType,
}

impl Model {
    /// Loads the model from `vb`; all weights live under the `model` prefix.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        // A single rotary table is shared by every layer.
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let vb_l = vb_m.pp("layers");
        let layers = (0..cfg.num_hidden_layers)
            .map(|idx| DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(idx)))
            .collect::<Result<Vec<_>>>()?;
        let norm = layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb_m.pp("norm"))?;
        // The head has no weights of its own: it reuses the embedding matrix, without bias.
        let lm_head = candle_nn::Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    /// Builds the additive attention mask of shape
    /// `(b_size, 1, tgt_len, tgt_len + seqlen_offset)`.
    ///
    /// Among the new positions, entry `(i, j)` is `-inf` when `j` lies after
    /// `i` or more than the sliding window before it, and zero otherwise; the
    /// `seqlen_offset` cached positions are never masked.
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Without a configured window, use one that no position can fall out of.
        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 42);
        let mut mask = Vec::with_capacity(tgt_len * tgt_len);
        for i in 0..tgt_len {
            for j in 0..tgt_len {
                let hidden = i < j || j + sliding_window < i;
                mask.push(if hidden { f32::NEG_INFINITY } else { 0f32 });
            }
        }
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset == 0 {
            mask
        } else {
            let cached = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&cached, &mask], D::Minus1)?
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    /// Runs the model on the 2D token-id tensor `input_ids` and returns the
    /// head output for the last position only.
    ///
    /// No mask is built when the input holds at most one token.
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = (seq_len > 1)
            .then(|| self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset))
            .transpose()?;
        let mut hidden = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            hidden = layer.forward(&hidden, attention_mask.as_ref(), seqlen_offset)?;
        }
        hidden
            .narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    /// Empties the key/value cache of every layer.
    pub fn clear_kv_cache(&mut self) {
        self.layers
            .iter_mut()
            .for_each(DecoderLayer::clear_kv_cache);
    }
}

View File

@ -0,0 +1,156 @@
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
/// Layer norm without learned affine parameters of its own: the scale and shift
/// applied after normalisation are looked up from embedding tables by a conditioning id.
pub struct AdaLayerNorm {
// Epsilon passed to `layer_norm`.
eps: f64,
// Embedding dimension given at construction.
dim: usize,
// Embedding table of scale vectors, loaded from the "scale" prefix.
scale: Embedding,
// Embedding table of shift vectors, loaded from the "shift" prefix.
shift: Embedding,
}
/// Normalises `x` over its last dimension to zero mean and unit variance, with no
/// learned scale or bias.
///
/// `F16` and `BF16` inputs are computed in `F32`; the result is converted back to the
/// input dtype. `eps` is added to the variance before taking the square root.
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
    let out_dtype = x.dtype();
    let compute_dtype = if matches!(out_dtype, DType::F16 | DType::BF16) {
        DType::F32
    } else {
        out_dtype
    };
    let n = x.dim(D::Minus1)? as f64;
    let x = x.to_dtype(compute_dtype)?;
    let mean = (x.sum_keepdim(D::Minus1)? / n)?;
    let centered = x.broadcast_sub(&mean)?;
    let variance = (centered.sqr()?.sum_keepdim(D::Minus1)? / n)?;
    let std_dev = (variance + eps)?.sqrt()?;
    centered.broadcast_div(&std_dev)?.to_dtype(out_dtype)
}
impl AdaLayerNorm {
    /// Loads the "scale" and "shift" embedding tables from `vb`.
    ///
    /// Each table has `num_embeddings` rows of size `embedding_dim`; `eps` is the
    /// epsilon used by the normalisation in [`AdaLayerNorm::forward`].
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        eps: f64,
        vb: VarBuilder,
    ) -> Result<Self> {
        let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
        let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
        Ok(Self {
            eps,
            dim: embedding_dim,
            scale,
            shift,
        })
    }

    /// Normalises `xs` over its last dimension, then applies the scale and shift
    /// looked up for `cond_embedding_id`.
    ///
    /// The looked-up scale/shift are broadcast against the normalised input, so
    /// their shape does not have to match `xs` exactly (for instance one id per
    /// batch rather than one per position).
    pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
        let scale = self.scale.forward(cond_embedding_id)?;
        let shift = self.shift.forward(cond_embedding_id)?;
        // The `*` / `+` operators require identical shapes; the broadcast variants
        // give the same result when shapes already match and also accept
        // broadcastable ones.
        layer_norm(xs, self.eps)?
            .broadcast_mul(&scale)?
            .broadcast_add(&shift)
    }
}
/// ConvNeXt-style 1-D block: a depthwise convolution, two pointwise linear layers
/// with a GELU in between, an optional per-channel scale, and a residual connection.
pub struct ConvNeXtBlock {
// Depthwise convolution (kernel size 7, padding 3, one group per channel).
dwconv: Conv1d,
// Pointwise projection from `dim` to `intermediate_dim`.
pwconv1: Linear,
// Pointwise projection from `intermediate_dim` back to `dim`.
pwconv2: Linear,
// Optional scale of size `dim`; only loaded when the layer-scale init value is > 0.
gamma: Option<Tensor>,
}
impl ConvNeXtBlock {
    /// Loads a block from `vb`.
    ///
    /// * `dim` - number of channels of the block input and output.
    /// * `intermediate_dim` - width of the hidden pointwise layer.
    /// * `layer_scale_init_value` - the "gamma" scale is only loaded when this is
    ///   strictly positive; its value is otherwise unused since weights come from `vb`.
    /// * `adanorm_num_embeddings` - currently unused: the norm layer of the block is
    ///   not implemented yet (see the TODO in `forward`).
    pub fn new(
        dim: usize,
        intermediate_dim: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let dwconv = {
            // Kernel 7 with padding 3 keeps the length; groups == dim makes it depthwise.
            let cfg = Conv1dConfig {
                padding: 3,
                groups: dim,
                ..Default::default()
            };
            conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
        };
        let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
        let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
        let gamma = if layer_scale_init_value > 0. {
            Some(vb.get(dim, "gamma")?)
        } else {
            None
        };
        Ok(Self {
            dwconv,
            pwconv1,
            pwconv2,
            gamma,
        })
    }

    /// Applies the block to `xs` and adds the input back as a residual.
    ///
    /// The input goes through the depthwise convolution, is transposed on dims 1 and 2
    /// so the linear layers act on the channel dimension, optionally scaled by `gamma`,
    /// and transposed back before the residual addition.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
        // TODO: norm
        // NOTE(review): confirm that `gelu` is the GELU variant the reference model uses.
        let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
        let xs = match self.gamma.as_ref() {
            // `gamma` has shape (dim) while `xs` has extra leading dimensions, so the
            // multiplication has to broadcast; the plain `*` operator requires equal shapes.
            Some(gamma) => xs.broadcast_mul(gamma)?,
            None => xs,
        };
        xs.transpose(1, 2)? + residual
    }
}
/// Vocos backbone: an input convolution, a stack of ConvNeXt blocks and a final layer norm.
struct VocosBackbone {
// Input convolution (kernel size 7, padding 3) from `input_channels` to `dim`.
embed: Conv1d,
// ConvNeXt blocks, loaded from "convnext.<idx>".
convnext: Vec<ConvNeXtBlock>,
// Layer norm of size `dim` with epsilon 1e-6, loaded from "final_layer_norm".
final_layer_norm: candle_nn::LayerNorm,
}
impl VocosBackbone {
    /// Loads the backbone from `vb`.
    ///
    /// * `input_channels` - number of channels of the input features.
    /// * `dim` - number of channels used inside the backbone.
    /// * `intermediate_dim` - hidden width of each ConvNeXt block.
    /// * `num_layers` - number of ConvNeXt blocks, loaded from "convnext.0",
    ///   "convnext.1", ...
    /// * `layer_scale_init_value`, `adanorm_num_embeddings` - forwarded to
    ///   [`ConvNeXtBlock::new`].
    pub fn new(
        input_channels: usize,
        dim: usize,
        intermediate_dim: usize,
        num_layers: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let embed = {
            // Kernel 7 with padding 3 keeps the sequence length unchanged.
            let cfg = Conv1dConfig {
                padding: 3,
                ..Default::default()
            };
            conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
        };
        let vb_c = vb.pp("convnext");
        // Collecting into `Result` stops at the first block that fails to load.
        let convnext = (0..num_layers)
            .map(|i| {
                ConvNeXtBlock::new(
                    dim,
                    intermediate_dim,
                    layer_scale_init_value,
                    adanorm_num_embeddings,
                    vb_c.pp(i),
                )
            })
            .collect::<Result<Vec<_>>>()?;
        let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
        Ok(Self {
            embed,
            convnext,
            final_layer_norm,
        })
    }

    /// Runs the backbone on `xs`.
    ///
    /// The input convolution and the ConvNeXt blocks operate with channels on
    /// dim 1; the result is transposed on dims 1 and 2 so that the final layer norm
    /// normalises over the channel dimension.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.embed)?;
        // TODO: norm
        // Each block starts with a convolution over `dim` channels, so the
        // activations must stay channels-first until the blocks are done.
        for conv_block in self.convnext.iter() {
            xs = conv_block.forward(&xs)?
        }
        xs.transpose(1, 2)?.apply(&self.final_layer_norm)
    }
}

View File

@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
mel mel
} }
pub fn log_mel_spectrogram_<T: Float>( fn log_mel_spectrogram_<T: Float>(
samples: &[T], samples: &[T],
filters: &[T], filters: &[T],
fft_size: usize, fft_size: usize,

View File

@ -47,12 +47,6 @@ impl Linear {
} }
} }
/// Builds a traced `Linear` by wrapping `candle_nn::linear_b(d1, d2, b, vb)`
/// together with a TRACE-level "linear" span.
pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result<Linear> {
    // The span is only created once the inner layer has loaded successfully.
    candle_nn::linear_b(d1, d2, b, vb).map(|inner| Linear {
        inner,
        span: tracing::span!(tracing::Level::TRACE, "linear"),
    })
}
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> { pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
let inner = candle_nn::linear(d1, d2, vb)?; let inner = candle_nn::linear(d1, d2, vb)?;
let span = tracing::span!(tracing::Level::TRACE, "linear"); let span = tracing::span!(tracing::Level::TRACE, "linear");