Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)

Compare commits — 1 commit

Commit: 3f3730b657
Cargo.toml (19 changed lines)

@@ -19,7 +19,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.4.1"
+version = "0.4.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -31,18 +31,17 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.4.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.4.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" }
-candle-nn = { path = "./candle-nn", version = "0.4.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.4.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.4.1" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
+candle-nn = { path = "./candle-nn", version = "0.4.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.10.0", features = ["f16"] }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
README.md (25 changed lines)

@@ -63,8 +63,6 @@ We also provide a some command line based examples using state of the art models
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
   the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
-  Deepmind.
 - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets. Also supports
@@ -76,10 +74,9 @@ We also provide a some command line based examples using state of the art models
 - [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
   experts 8x7b general LLM with better performance than a Llama 2 70B model with
   much faster inference.
-- [StarCoder](./candle-examples/examples/bigcode/) and
-  [StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
+- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
 - [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
-- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
+- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
   performance.
 - [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
 - [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
@@ -110,12 +107,7 @@ We also provide a some command line based examples using state of the art models
 
 <img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
 
-- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
 - [Whisper](./candle-examples/examples/whisper/): speech recognition model.
-- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
-  model using residual vector quantization.
-- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
-  text-to-speech.
 - [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
   [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
 - [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@@ -195,10 +187,9 @@ If you have an addition to this list, please submit a pull request.
 - Language Models.
     - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
     - Falcon.
-    - StarCoder, StarCoder2.
+    - StarCoder.
     - Phi 1, 1.5, and 2.
     - Mamba, Minimal Mamba
-    - Gemma 2b and 7b.
     - Mistral 7b v0.1.
     - Mixtral 8x7b v0.1.
    - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@@ -206,7 +197,7 @@ If you have an addition to this list, please submit a pull request.
     - Bert.
     - Yi-6B and Yi-34B.
     - Qwen1.5.
-    - RWKV v5 and v6.
+    - RWKV.
 - Quantized LLMs.
     - Llama 7b, 13b, 70b, as well as the chat and code variants.
     - Mistral 7b, and 7b instruct.
@@ -216,22 +207,18 @@ If you have an addition to this list, please submit a pull request.
 - Text to text.
     - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
     - Marian MT (Machine Translation).
+    - Whisper (multi-lingual support).
 - Text to image.
     - Stable Diffusion v1.5, v2.1, XL v1.0.
     - Wurstchen v2.
 - Image to text.
     - BLIP.
     - TrOCR.
-- Audio.
-    - Whisper, multi-lingual speech-to-text.
-    - EnCodec, audio compression model.
-    - MetaVoice-1B, text-to-speech model.
 - Computer Vision Models.
     - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
-      ConvNeXTv2, MobileOne, EfficientVit (MSRA).
+      ConvNeXTv2.
     - yolo-v3, yolo-v8.
     - Segment-Anything Model (SAM).
-    - SegFormer.
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
 - Serverless (on CPU), small and fast deployments.
 - Quantization support using the llama.cpp quantized types.
@@ -5,32 +5,25 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;
 
 use anyhow::Result;
-use candle_core::{Device, Module, Tensor};
-
-use candle_core::quantized::{QMatMul, QTensor};
+use candle_core::{Device, Tensor};
 
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
-    let q_cpu = q.to_device(&Device::Cpu)?;
-    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
-    let q = QMatMul::from_qtensor(q)?;
-    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
-    let res_q_cuda = q.forward(&x)?;
-    println!("{res_q_cuda}");
-
-    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
-    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
-    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
-    let x_cpu = x.to_device(&Device::Cpu)?;
-    let res_q_cpu = q_cpu.forward(&x_cpu)?;
-    println!("{res_q_cpu}");
-
-    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
-    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
-        .abs()?
-        .flatten_all()?
-        .max(0)?;
+    let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
+    let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
+    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
+    println!("{out_t}");
+    let in_t = in_t.to_device(&Device::Cpu)?;
+    let k_t = k_t.to_device(&Device::Cpu)?;
+    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
+    let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
+        .sqr()?
+        .sum_all()?;
+    println!("{diff}");
 
     let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
     let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
     let res = t.conv2d(&w, 1, 1, 1, 1)?;
     println!("{res:?}");
     Ok(())
 }
@@ -113,7 +113,7 @@ impl Tensor {
                 | Op::Unary(_node, UnaryOp::Floor)
                 | Op::Unary(_node, UnaryOp::Round) => nodes,
                 Op::Reshape(node)
-                | Op::UpsampleNearest1D { arg: node, .. }
+                | Op::UpsampleNearest1D(node)
                 | Op::UpsampleNearest2D { arg: node, .. }
                 | Op::AvgPool2D { arg: node, .. }
                 | Op::MaxPool2D { arg: node, .. }
@@ -250,7 +250,6 @@ impl Tensor {
                         out_padding,
                         *stride,
                         *dilation,
-                        /* groups */ 1,
                     )?;
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
@@ -348,18 +347,9 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad_arg)?;
                 }
-                Op::UpsampleNearest1D { arg, target_size } => {
-                    let (_n, c, size) = arg.dims3()?;
-                    if target_size % size != 0 {
-                        crate::bail!("backward not supported for non integer upscaling factors")
-                    }
-                    let scale = target_size / size;
-
-                    let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
-                    let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
-                    let sum_grad = grads.or_insert(arg)?;
-                    *sum_grad = conv_sum;
-                }
+                Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
+                    op: "upsample-nearest1d",
+                })?,
                 Op::UpsampleNearest2D {
                     arg,
                     target_h,
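The `UpsampleNearest1D` backward removed above encodes a neat identity: nearest-neighbor upsampling by an integer factor copies each input element `scale` times, so its adjoint just sums each window of `scale` consecutive output gradients per channel — exactly a depthwise `conv1d` with a ones kernel and stride `scale`. A self-contained sketch of the same trick with public candle APIs (the helper name is ours, not candle's):

```rust
// Sketch only: reproduces the removed upsample-nearest1d backward trick.
// `grad` has shape (n, c, size * scale); summing each disjoint window of
// `scale` positions per channel recovers the input gradient.
use candle_core::{DType, Device, Result, Tensor};

fn upsample1d_grad(grad: &Tensor, scale: usize) -> Result<Tensor> {
    let (_n, c, _l) = grad.dims3()?;
    // A ones kernel of shape (c, 1, scale) with stride=scale and groups=c sums
    // windows of `scale` elements independently per channel (no cross-channel mixing).
    let kernel = Tensor::ones((c, 1, scale), grad.dtype(), grad.device())?;
    grad.conv1d(&kernel, /*padding*/ 0, /*stride*/ scale, /*dilation*/ 1, /*groups*/ c)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let grad = Tensor::ones((1, 2, 6), DType::F32, &dev)?;
    // With scale=2, each of the 3 input positions receives a gradient of 2.0.
    println!("{}", upsample1d_grad(&grad, 2)?);
    Ok(())
}
```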
@@ -187,16 +187,36 @@ impl Tensor {
         }
     }
 
-    fn conv_transpose1d_single_group(
+    /// Applies a 1D transposed convolution over the input tensor.
+    pub fn conv_transpose1d(
         &self,
         kernel: &Self,
-        params: &ParamsConvTranspose1D,
+        padding: usize,
+        output_padding: usize,
+        stride: usize,
+        dilation: usize,
     ) -> Result<Self> {
+        let (b_size, c_in, l_in) = self.dims3()?;
+        let (c_in_k, c_out, k_size) = kernel.dims3()?;
+        if c_in != c_in_k {
+            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
+        }
+        let params = ParamsConvTranspose1D {
+            b_size,
+            l_in,
+            k_size,
+            c_out,
+            c_in,
+            padding,
+            output_padding,
+            stride,
+            dilation,
+        };
         let storage = self.storage().conv_transpose1d(
             self.layout(),
             &kernel.storage(),
             kernel.layout(),
-            params,
+            &params,
         )?;
         let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
             arg,
@@ -210,49 +230,6 @@ impl Tensor {
         Ok(crate::tensor::from_storage(storage, out_dims, op, false))
     }
 
-    /// Applies a 1D transposed convolution over the input tensor.
-    pub fn conv_transpose1d(
-        &self,
-        kernel: &Self,
-        padding: usize,
-        output_padding: usize,
-        stride: usize,
-        dilation: usize,
-        groups: usize,
-    ) -> Result<Self> {
-        let (c_in_k, c_out, k_size) = kernel.dims3()?;
-        let (b_size, c_in, l_in) = self.dims3()?;
-        if c_in != c_in_k {
-            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
-        }
-        if c_in % groups != 0 {
-            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
-        }
-        let params = ParamsConvTranspose1D {
-            b_size,
-            l_in,
-            k_size,
-            c_out,
-            c_in: c_in / groups,
-            padding,
-            output_padding,
-            stride,
-            dilation,
-        };
-        if groups == 1 {
-            self.conv_transpose1d_single_group(kernel, &params)
-        } else {
-            let blocks = self.chunk(groups, 1)?;
-            let kernel = kernel.chunk(groups, 0)?;
-            let blocks = blocks
-                .iter()
-                .zip(&kernel)
-                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
-                .collect::<Result<Vec<_>>>()?;
-            Tensor::cat(&blocks, 1)
-        }
-    }
-
     fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
         let storage =
             self.storage()
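The grouped path deleted here implements groups the generic way: chunk the input along the channel dimension (dim 1) and the kernel along dim 0, run each pair through the single-group kernel, then concatenate the per-group outputs along dim 1. A hedged usage sketch against the 7-argument signature shown above (the shapes are illustrative, not from the source):

```rust
// Sketch, assuming the grouped signature:
// conv_transpose1d(kernel, padding, output_padding, stride, dilation, groups).
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input: batch=1, c_in=4, length=5; kernel layout is (c_in, c_out_per_group, k).
    let t = Tensor::randn(0f32, 1.0, (1, 4, 5), &dev)?;
    let w = Tensor::randn(0f32, 1.0, (4, 3, 3), &dev)?;
    // groups=2 splits the 4 input channels into two blocks of 2 and the kernel
    // into two (2, 3, 3) chunks; the per-group outputs are concatenated along
    // dim 1, giving 2 * 3 = 6 output channels. Output length is
    // (5 - 1) * 1 - 0 + 1 * (3 - 1) + 0 + 1 = 7.
    let out = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
    assert_eq!(out.dims(), [1, 6, 7]);
    Ok(())
}
```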
@@ -1263,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
     fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
         let p = self.0;
         let inp = &inp[inp_l.start_offset()..];
         let k = &k[k_l.start_offset()..];
         let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
         let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
         let l_out = p.l_out();
@@ -2575,7 +2574,7 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
             Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
         }
     }
@@ -2584,7 +2583,7 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
             Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
         }
     }
@@ -2601,7 +2600,7 @@ impl BackendStorage for CpuStorage {
             Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
             Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
         }
     }
@@ -129,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
     }
 }
 
-impl<M: Module> Module for Option<&M> {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        match self {
-            None => Ok(xs.clone()),
-            Some(m) => m.forward(xs),
-        }
-    }
-}
-
 // A trait defining a module with forward method using a single tensor argument and a flag to
 // separate the training and evaluation behaviors.
 pub trait ModuleT {
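The blanket impl being removed is a small but handy convenience: an `Option<&M>` acts as the identity when `None`, which lets model code thread optional layers (a bias, a norm, a projection) through `forward` without branching at every call site. A hedged sketch of the pattern it enabled (the `Block` type is illustrative):

```rust
// Sketch of the pattern enabled by `impl<M: Module> Module for Option<&M>`:
// an absent layer behaves as the identity function.
use candle_core::{Result, Tensor};
use candle_nn::{LayerNorm, Linear, Module};

struct Block {
    proj: Linear,
    norm: Option<LayerNorm>,
}

impl Block {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.proj.forward(xs)?;
        // With the Option impl available, this whole match collapses to
        // `self.norm.as_ref().forward(&xs)`; without it, callers write it out.
        match &self.norm {
            Some(norm) => norm.forward(&xs),
            None => Ok(xs),
        }
    }
}
```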
@@ -738,7 +738,6 @@ impl BackendStorage for MetalStorage {
                 ("ufloor", DType::F32) => strided::floor::FLOAT,
                 ("urelu", DType::F32) => strided::relu::FLOAT,
-                ("uround", DType::F32) => strided::round::FLOAT,
                 ("utanh", DType::F32) => strided::tanh::FLOAT,
                 ("ucos", DType::F16) => strided::cos::HALF,
                 ("usin", DType::F16) => strided::sin::HALF,
                 ("usqr", DType::F16) => strided::sqr::HALF,
@@ -755,7 +754,6 @@ impl BackendStorage for MetalStorage {
                 ("ufloor", DType::F16) => strided::floor::HALF,
                 ("urelu", DType::F16) => strided::relu::HALF,
-                ("uround", DType::F16) => strided::round::HALF,
                 ("utanh", DType::F16) => strided::tanh::HALF,
                 (name, dtype) => {
                     crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
                 }
@@ -829,9 +827,9 @@ impl BackendStorage for MetalStorage {
                     layout.start_offset() * self.dtype.size_in_bytes(),
                 ),
                 &t.buffer,
-                (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
+                (&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
                 &f.buffer,
-                (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
+                (&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -1266,7 +1264,7 @@ impl BackendStorage for MetalStorage {
             let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
             let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
             let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
+            blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
             blit.end_encoding();
         } else {
             let src_shape = src_l.shape();
@@ -1638,7 +1636,7 @@ impl BackendDevice for MetalDevice {
                 min as f32,
                 max as f32,
                 shape.elem_count(),
-                &self.seed.lock().unwrap(),
+                &*self.seed.lock().unwrap(),
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -1669,7 +1667,7 @@ impl BackendDevice for MetalDevice {
                 mean as f32,
                 stddev as f32,
                 shape.elem_count(),
-                &self.seed.lock().unwrap(),
+                &*self.seed.lock().unwrap(),
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -132,10 +132,7 @@ pub enum Op {
         stride: (usize, usize),
     },
 
-    UpsampleNearest1D {
-        arg: Tensor,
-        target_size: usize,
-    },
+    UpsampleNearest1D(Tensor),
     UpsampleNearest2D {
         arg: Tensor,
         target_h: usize,
@@ -42,7 +42,7 @@ pub enum OpCode {
     Stop = b'.',
     NewObj = 0x81,
     EmptyList = b']',
-    BinFloat = b'G',
+    BinFloat = b'g',
     Append = b'a',
     Appends = b'e',
 }
@@ -462,10 +462,7 @@ impl Stack {
                 self.push(Object::Int(arg))
             }
             OpCode::BinFloat => {
-                // Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
-                // https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
-                // https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
-                let arg = r.read_f64::<byteorder::BigEndian>()?;
+                let arg = r.read_f64::<LittleEndian>()?;
                 self.push(Object::Float(arg))
             }
             OpCode::BinUnicode => {
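The deleted comment is the key fact here: CPython's BINFLOAT opcode (`'G'`) serializes an f64 in big-endian byte order, even though the integer opcodes are little-endian (see the pickletools link quoted above). A quick check of the encoding, as a standalone sketch:

```rust
// Sketch: BINFLOAT ('G') payloads are big-endian f64s per CPython pickletools.
// For example, pickling 1.5 emits the opcode 'G' followed by these 8 bytes.
use byteorder::{BigEndian, ReadBytesExt};

fn main() -> std::io::Result<()> {
    let payload: [u8; 8] = [0x3f, 0xf8, 0, 0, 0, 0, 0, 0]; // 1.5 as big-endian f64
    let mut r = &payload[..];
    let v = r.read_f64::<BigEndian>()?;
    assert_eq!(v, 1.5); // little-endian decoding would yield garbage here
    Ok(())
}
```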
@@ -1,343 +0,0 @@
use super::{GgmlDType, QStorage};
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};

use cudarc::driver::{CudaSlice, DeviceSlice};

pub struct QCudaStorage {
    data: CudaSlice<u8>,
    dtype: GgmlDType,
    device: CudaDevice,
}

pub const WARP_SIZE: usize = 32;
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
pub const NWARPS_Q4_0_AMPERE: usize = 4;
pub const GGML_CUDA_MMV_X: usize = 32;
pub const GGML_CUDA_MMV_Y: usize = 1;
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;

fn dequantize(
    data: &CudaSlice<u8>,
    dtype: GgmlDType,
    elem_count: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    let nb = (elem_count + 255) / 256;
    let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
        GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
        GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
        GgmlDType::Q5_0 => {
            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
            (
                "dequantize_block_q5_0",
                false,
                CUDA_DEQUANTIZE_BLOCK_SIZE,
                nb,
            )
        }
        GgmlDType::Q5_1 => {
            let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
                / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
            (
                "dequantize_block_q5_1",
                false,
                CUDA_DEQUANTIZE_BLOCK_SIZE,
                nb,
            )
        }
        GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
        GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
        GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
        GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
        GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
        GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
        GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
        _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
    // See e.g.
    // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (num_blocks as u32, 1, 1),
        block_dim: (block_dim as u32, 1, 1),
        shared_mem_bytes: 0,
    };

    if is_k {
        let params = (data, &dst);
        unsafe { func.launch(cfg, params) }.w()?;
    } else {
        let nb32 = match dtype {
            GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
            _ => elem_count / 32,
        };
        let params = (data, &dst, nb32 as i32);
        unsafe { func.launch(cfg, params) }.w()?;
    }
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

fn dequantize_mut_mal_vec(
    data: &CudaSlice<u8>,
    y: &cudarc::driver::CudaView<f32>,
    dtype: GgmlDType,
    ncols: usize,
    nrows: usize,
    dev: &CudaDevice,
) -> Result<CudaStorage> {
    use cudarc::driver::LaunchAsync;

    let kernel_name = match dtype {
        GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
        GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
        GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
        GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
        GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
        GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
        GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
        GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
        GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
        GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
    };
    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
    let dst = dev.alloc_zeros::<f32>(nrows).w()?;
    let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    let cfg = cudarc::driver::LaunchConfig {
        grid_dim: (block_num_y as u32, 1, 1),
        block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
        shared_mem_bytes: 0,
    };

    let params = (data, y, &dst, ncols as i32, nrows as i32);
    unsafe { func.launch(cfg, params) }.w()?;
    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}

impl QCudaStorage {
    pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
        let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
        let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
        Ok(QCudaStorage {
            data,
            device: device.clone(),
            dtype,
        })
    }

    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    pub fn device(&self) -> &CudaDevice {
        &self.device
    }

    pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
        let fast_kernel = matches!(
            self.dtype,
            GgmlDType::Q4_0
                | GgmlDType::Q4_1
                | GgmlDType::Q5_0
                | GgmlDType::Q5_1
                | GgmlDType::Q8_0
                | GgmlDType::Q2K
                | GgmlDType::Q3K
                | GgmlDType::Q4K
                | GgmlDType::Q5K
                | GgmlDType::Q6K
                | GgmlDType::Q8K
        );
        if fast_kernel {
            return dequantize(&self.data, self.dtype, elem_count, self.device());
        }
        // Run the dequantization on cpu.
        use crate::quantized::k_quants::GgmlType;

        let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
        let mut out = vec![0.0; elem_count];
        let block_len = elem_count / self.dtype.block_size();
        match self.dtype {
            GgmlDType::F32 => {
                let slice =
                    unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
                out.copy_from_slice(slice)
            }
            GgmlDType::F16 => {
                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
                half::f16::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_0 => {
                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4_1 => {
                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_0 => {
                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5_1 => {
                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_0 => {
                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8_1 => {
                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q2K => {
                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q3K => {
                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q4K => {
                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q5K => {
                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q6K => {
                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
            }
            GgmlDType::Q8K => {
                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
            }
        }

        self.device
            .storage_from_cpu_storage(&crate::CpuStorage::F32(out))
    }

    pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
        // Run the quantization on cpu.
        let src = match &src.slice {
            crate::cuda_backend::CudaStorageSlice::F32(data) => {
                self.device.dtoh_sync_copy(data).w()?
            }
            _ => crate::bail!("only f32 can be quantized"),
        };
        let src_len = src.len();
        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
        let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
        qcpu_storage.quantize(&src)?;
        let data = qcpu_storage.data()?;
        let data = self.device.htod_sync_copy(data.as_ref()).w()?;
        self.data = data;
        Ok(())
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.data.len()
    }

    pub fn fwd(
        &self,
        self_shape: &crate::Shape,
        storage: &CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
            self.dequantize_matmul_vec(self_shape, storage, layout)
        } else {
            self.dequantize_matmul(self_shape, storage, layout)
        }
    }
}

impl QCudaStorage {
    fn dequantize_matmul_vec(
        &self,
        self_shape: &crate::Shape,
        rhs: &CudaStorage,
        rhs_l: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        let (nrows, ncols) = self_shape.dims2()?;
        let rhs = rhs.as_cuda_slice::<f32>()?;
        let rhs = match rhs_l.contiguous_offsets() {
            Some((o1, o2)) => rhs.slice(o1..o2),
            None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
        };
        let (with_batch, k) = match rhs_l.shape().dims() {
            [1, 1, k] => (true, k),
            [1, k] => (false, k),
            _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
        };
        if ncols != *k {
            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
        }

        let out =
            dequantize_mut_mal_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
        let out_shape = if with_batch {
            vec![1, 1, nrows]
        } else {
            vec![1, nrows]
        };
        Ok((out, out_shape.into()))
    }

    fn dequantize_matmul(
        &self,
        self_shape: &crate::Shape,
        storage: &CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        use crate::backend::BackendStorage;
        let (n, k) = self_shape.dims2()?;
        let (b, m, k2) = match layout.shape().dims() {
            &[b, m, k2] => (b, m, k2),
            &[m, k2] => (1, m, k2),
            s => crate::bail!("unexpected shape for input {s:?}"),
        };
        if k2 != k {
            crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
        }

        let data_f32 = self.dequantize(n * k)?;
        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
        let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
        let mut out_shape = layout.shape().dims().to_vec();
        out_shape.pop();
        out_shape.push(n);
        Ok((out, out_shape.into()))
    }
}

fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
    let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
    slice.to_vec()
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    device: &CudaDevice,
    data: &[T],
) -> Result<super::QStorage> {
    let data = unsafe {
        std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
    };
    let data = device.htod_sync_copy(data).w()?;
    Ok(QStorage::Cuda(QCudaStorage {
        data,
        device: device.clone(),
        dtype: T::DTYPE,
    }))
}
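A detail worth noting from the deleted launcher above is its grid sizing: `(elem_count + 255) / 256` is ceiling division, launching just enough fixed-width blocks to cover every element, and the Q5 variants divide by `2 * CUDA_DEQUANTIZE_BLOCK_SIZE` because those kernels process two elements per thread. A tiny sketch of the arithmetic:

```rust
// Sketch of the grid-sizing arithmetic used by the removed CUDA launcher.
fn ceil_div(n: usize, d: usize) -> usize {
    (n + d - 1) / d
}

fn main() {
    const BLOCK: usize = 256; // CUDA_DEQUANTIZE_BLOCK_SIZE above
    assert_eq!(ceil_div(1, BLOCK), 1); // even 1 element needs a full block
    assert_eq!(ceil_div(256, BLOCK), 1);
    assert_eq!(ceil_div(257, BLOCK), 2); // partial remainder gets its own block
    // Q5_0/Q5_1 cover two elements per thread, hence the 2 * BLOCK divisor.
    assert_eq!(ceil_div(1024, 2 * BLOCK), 2);
}
```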
@@ -1,50 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{CudaDevice, CudaStorage, Error, Result};

pub struct QCudaStorage {
    dtype: GgmlDType,
    device: CudaDevice,
}

impl QCudaStorage {
    pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn dtype(&self) -> GgmlDType {
        self.dtype
    }

    pub fn device(&self) -> &CudaDevice {
        &self.device
    }

    pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        0
    }

    pub fn fwd(
        &self,
        _self_shape: &crate::Shape,
        _storage: &CudaStorage,
        _layout: &crate::Layout,
    ) -> Result<(CudaStorage, crate::Shape)> {
        Err(Error::NotCompiledWithCudaSupport)
    }
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
    _device: &CudaDevice,
    _data: &[T],
) -> Result<super::QStorage> {
    Err(Error::NotCompiledWithCudaSupport)
}
@@ -41,10 +41,3 @@ impl QMetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }
 }
-
-pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
-    _device: &MetalDevice,
-    _data: &[T],
-) -> Result<super::QStorage> {
-    Err(Error::NotCompiledWithMetalSupport)
-}
@@ -1,5 +1,7 @@
 //! Support for the GGML file format.
 
+#[cfg(feature = "metal")]
+use super::metal::load_quantized_metal;
 use super::{k_quants, GgmlDType, QStorage};
 use crate::{Device, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
@@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
     let data: QStorage = match device {
         Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
-        Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
-        Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
+        #[cfg(feature = "metal")]
+        Device::Metal(metal) => load_quantized_metal(metal, data)?,
+        #[cfg(not(feature = "metal"))]
+        Device::Metal(_metal) => {
+            crate::bail!("Metal backend requires `metal` feature")
+        }
+        device => unimplemented!("Implement quantized tensor for device {device:?}"),
     };
     super::QTensor::new(data, dims)
 }
@@ -34,8 +34,6 @@ impl QMetalStorage {
     }
 
     pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
-        use crate::quantized::k_quants::GgmlType;
-
         let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("to_cpu");
@@ -45,62 +43,81 @@ impl QMetalStorage {
         blit.end_encoding();
         self.device.wait_until_completed()?;
         let mut out = vec![0.0; elem_count];
-        let block_len = elem_count / self.dtype.block_size();
         match self.dtype {
             GgmlDType::F32 => {
-                let vec: Vec<f32> = read_to_vec(&buffer, block_len);
+                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 f32::to_float(&vec, &mut out)?;
             }
             GgmlDType::F16 => {
-                let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
+                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 half::f16::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q4_0 => {
-                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q4_1 => {
-                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q5_0 => {
-                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q5_1 => {
-                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q8_0 => {
-                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q8_1 => {
-                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q2K => {
-                let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ2K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q3K => {
-                let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ3K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q4K => {
-                let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ4K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q5K => {
-                let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ5K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q6K => {
-                let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ6K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
             }
             GgmlDType::Q8K => {
-                let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
+                let vec: Vec<crate::quantized::BlockQ8K> =
+                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
+                use crate::quantized::k_quants::GgmlType;
                 crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
             }
         }
@@ -175,7 +192,7 @@ impl QMetalStorage {
     }
 }
 
-pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
+pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
     device: &MetalDevice,
     data: &[T],
 ) -> Result<QStorage> {
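Both sides of this hunk compute the same quantity for the K-quant arms: `elem_count / self.dtype.block_size()` is the number of quantized blocks backing the buffer. The change only matters for the unquantized arms, where the new code reads `elem_count` values directly; this coincides with the block count when the "block" is a single scalar (which is how the GGML layout treats F32/F16). A sketch of the arithmetic, with block sizes that follow the GGML layout but should be treated as illustrative:

```rust
// Sketch: mapping element counts to quantized block counts.
fn block_count(elem_count: usize, block_size: usize) -> usize {
    elem_count / block_size
}

fn main() {
    assert_eq!(block_count(1024, 1), 1024); // F32/F16: one value per "block"
    assert_eq!(block_count(1024, 32), 32);  // Q4_0-style: 32 elements per block
    assert_eq!(block_count(1024, 256), 4);  // Q8K-style: 256 elements per block
}
```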
@@ -4,7 +4,6 @@ use std::borrow::Cow;
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
-mod dummy_cuda;
 mod dummy_metal;
 pub mod ggml_file;
 pub mod gguf_file;
@@ -15,13 +14,6 @@ pub mod metal;
 mod metal {
     pub use super::dummy_metal::*;
 }
-#[cfg(feature = "cuda")]
-pub mod cuda;
-#[cfg(not(feature = "cuda"))]
-mod cuda {
-    pub use super::dummy_cuda::*;
-}
-
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
@@ -47,9 +39,8 @@ impl Device {
                 let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
                 Ok(QStorage::Metal(storage))
             }
-            Device::Cuda(cuda) => {
-                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
-                Ok(QStorage::Cuda(storage))
+            Device::Cuda(_cuda) => {
+                crate::bail!("Cuda ggml quantization not supported");
             }
         }
     }
@@ -58,7 +49,6 @@ impl Device {
 pub enum QStorage {
     Cpu(Box<dyn QuantizedType>),
     Metal(metal::QMetalStorage),
-    Cuda(cuda::QCudaStorage),
 }
 
 impl QStorage {
@@ -66,7 +56,6 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => storage.block_size(),
             QStorage::Metal(storage) => storage.dtype().block_size(),
-            QStorage::Cuda(storage) => storage.dtype().block_size(),
         }
     }
@@ -74,7 +63,6 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => storage.dtype(),
             QStorage::Metal(storage) => storage.dtype(),
-            QStorage::Cuda(storage) => storage.dtype(),
         }
     }
@@ -82,7 +70,6 @@ impl QStorage {
         match self {
             QStorage::Cpu(_storage) => Device::Cpu,
             QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
-            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
         }
     }
@@ -90,7 +77,6 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
             QStorage::Metal(storage) => storage.storage_size_in_bytes(),
-            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
         }
     }
@@ -100,7 +86,6 @@ impl QStorage {
                 storage.from_float(src.as_slice::<f32>()?)?;
             }
             (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
-            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
             _ => crate::bail!("Invalid dequantize storage locations do not match"),
         }
         Ok(())
@@ -110,7 +95,6 @@ impl QStorage {
         match self {
             QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
             QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
-            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
         }
     }
@@ -122,7 +106,7 @@ impl QStorage {
                 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                 Ok(Cow::from(data))
             }
-            QStorage::Metal(_) | QStorage::Cuda(_) => {
+            QStorage::Metal(_storage) => {
                 crate::bail!("not implemented");
             }
         }
@@ -440,7 +424,7 @@ impl crate::CustomOp1 for QTensor {
         #[allow(clippy::infallible_destructuring_match)]
         let self_storage = match &self.storage {
             QStorage::Cpu(storage) => storage,
-            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
+            QStorage::Metal(_) => crate::bail!("Invalid storage"),
         };
         let slice = storage.as_slice::<f32>()?;
         let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@@ -460,18 +444,6 @@ impl crate::CustomOp1 for QTensor {
         };
         self_storage.fwd(&self.shape, storage, layout)
     }
-
-    fn cuda_fwd(
-        &self,
-        storage: &crate::CudaStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::CudaStorage, Shape)> {
-        let self_storage = match &self.storage {
-            QStorage::Cuda(cuda) => cuda,
-            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
-        };
-        self_storage.fwd(&self.shape, storage, layout)
-    }
 }
 
 impl crate::Module for QMatMul {
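The module wiring deleted above is the standard feature-gate trick: compile the real backend when the `cuda` feature is on, and otherwise substitute a dummy module exposing the same public items, each returning `Error::NotCompiledWithCudaSupport`, so the rest of the crate compiles against one interface. A minimal sketch of the pattern (module and function names here are illustrative, not candle's):

```rust
// Sketch of the cfg-gated fallback-module pattern; the real crate swaps in
// `dummy_cuda` with the same public items as the CUDA backend.
mod dummy_backend {
    pub fn dequantize() -> Result<Vec<f32>, String> {
        Err("not compiled with cuda support".to_string())
    }
}

#[cfg(feature = "cuda")]
mod backend {
    // pub use real_cuda_impl::*; // re-export the real implementation here
}

#[cfg(not(feature = "cuda"))]
mod backend {
    pub use super::dummy_backend::*;
}

fn main() {
    // Callers always go through `backend::...` and get a uniform error when
    // the feature is disabled, instead of a compile failure.
    let _ = backend::dequantize();
}
```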
@@ -352,10 +352,6 @@ impl Storage {
                 let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
                 Ok(Self::Cuda(s))
             }
-            (Storage::Metal(inp), Storage::Metal(kernel)) => {
-                let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
-                Ok(Self::Metal(s))
-            }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
@@ -1015,7 +1015,7 @@ impl Tensor {
     /// tensor also has three dimensions, `(batch, channels, target_size)`.
     pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
         let (n, c, _l) = self.dims3()?;
-        let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
+        let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
         let storage = self
             .storage()
             .upsample_nearest1d(self.layout(), target_size)?;
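For context, `interpolate1d` performs nearest-neighbor upsampling along the last dimension of a `(batch, channels, length)` tensor; the change above only affects how the op is recorded for backprop, not the forward result. A small usage sketch:

```rust
// Sketch: nearest-neighbor 1D upsampling with interpolate1d.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[[1f32, 2., 3.]]], &Device::Cpu)?; // shape (1, 1, 3)
    let up = t.interpolate1d(6)?; // shape (1, 1, 6): each value repeated twice
    assert_eq!(up.to_vec3::<f32>()?, [[[1., 1., 2., 2., 3., 3.]]]);
    Ok(())
}
```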
@@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
 res = torch.nn.functional.conv_transpose1d(t, w_t)
 print(res.shape)
 print(res)
-res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
-print(res.shape)
-print(res)
 */
 fn conv1d(dev: &Device) -> Result<()> {
     let t = Tensor::new(
@@ -53,7 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> {
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
         [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
     );
-    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
+    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
     assert_eq!(res.dims(), [1, 2, 7]);
     assert_eq!(
         test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -62,17 +59,6 @@ fn conv1d(dev: &Device) -> Result<()> {
             4.7076, -5.9745, -0.8276, 1.621
         ],
     );
-    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
-    assert_eq!(res.dims(), [1, 4, 7]);
-    assert_eq!(
-        test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
-        [
-            [-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
-            [0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
-            [1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
-            [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
-        ]
-    );
     Ok(())
 }
@@ -283,38 +283,19 @@ fn unary_grad(device: &Device) -> Result<()> {
         [1.0881, 0.9277, 1.0527, 0.5747],
     );
 
-    if device.is_cpu() {
-        let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
-        let y = x.interpolate1d(12)?.reshape(36)?;
-
-        let z = Tensor::new(
-            &[
-                1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
-                17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
-                33., 34., 35., 36.,
-            ],
-            device,
-        )?;
-
-        let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
-        let grads = loss.backward()?;
-        let grad_x = grads.get(&x).context("no grad for x")?;
-
-        assert_eq!(
-            test_utils::to_vec3_round(grad_x, 4)?,
-            [[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
-        );
-    }
-
     // manually checked: see comments
     let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
     let y = x.interpolate2d(6, 6)?.reshape(36)?;
 
+    #[rustfmt::skip]
     let z = Tensor::new(
         &[
-            1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
-            18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
-            35., 36.,
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
         ],
         device,
     )?;
@@ -345,11 +326,15 @@ fn unary_grad(device: &Device) -> Result<()> {
     let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
     let y = x.interpolate2d(6, 6)?.reshape(36)?;
 
+    #[rustfmt::skip]
     let z = Tensor::new(
         &[
-            1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
-            18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
-            35., 36.,
+            1_f32, 02., 03., 04., 05., 06.,
+            07., 08., 09., 10., 11., 12.,
+            13., 14., 15., 16., 17., 18.,
+            19., 20., 21., 22., 23., 24.,
+            25., 26., 27., 28., 29., 30.,
+            31., 32., 33., 34., 35., 36.,
         ],
         device,
     )?;
@@ -178,6 +178,10 @@ test_device!(
 );
 
 fn quantize_q4_0(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
 
     let src = Tensor::from_slice(&src, (32 * 4,), device)?;
@@ -205,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
 }
 
 fn quantize_q4_1(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
     let src = Tensor::from_slice(&src, (32 * 4,), device)?;
     let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
@@ -231,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
 }
 
 fn quantize_q5_0(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
     let src = Tensor::from_slice(&src, (32 * 4,), device)?;
     let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
@@ -257,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
 }
 
 fn quantize_q5_1(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
     let src = Tensor::from_slice(&src, (32 * 4,), device)?;
     let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
@@ -357,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) {
 }
 
 fn quantize_q2k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let dtype = GgmlDType::Q2K;
 
     let src = get_test_vector2(0.5, 1024, device)?;
@@ -391,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
 }
 
 fn quantize_q3k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let dtype = GgmlDType::Q3K;
     let src = get_test_vector2(0.5, 1024, device)?;
     let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -424,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
 }
 
 fn quantize_q4k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let dtype = GgmlDType::Q4K;
     let src = get_test_vector2(0.5, 1024, device)?;
     let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -457,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
 }
 
 fn quantize_q5k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let dtype = GgmlDType::Q5K;
     let src = get_test_vector2(0.5, 1024, device)?;
     let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -490,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
 }
 
 fn quantize_q6k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let dtype = GgmlDType::Q6K;
     let src = get_test_vector2(0.5, 1024, device)?;
     let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -523,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
 }
 
 fn quantize_q8k(device: &Device) -> Result<()> {
+    // TODO Enable this later when we enable cuda.
+    if device.is_cuda() {
+        return Ok(());
+    }
     let dtype = GgmlDType::Q8K;
     let src = get_test_vector2(0.5, 1024, device)?;
     let quant = quantized::QTensor::quantize(&src, dtype)?;
@@ -738,6 +778,10 @@ macro_rules! quantized_matmul {
     // stable. https://github.com/rust-lang/rust/issues/29599
     ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
         fn $fn_name(device: &Device) -> Result<()> {
+            if device.is_cuda() {
+                // TODO Enable Cuda GGML sometime maybe.
+                return Ok(());
+            }
             test_matmul(device, (1, 3, 4, 256), $dtype)?;
             Ok(())
         }
@@ -12,7 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 candle = { workspace = true }
-candle-datasets = { workspace = true, optional = true }
+candle-datasets = { workspace = true }
 candle-nn = { workspace = true }
 candle-transformers = { workspace = true }
 candle-flash-attn = { workspace = true, optional = true }
@@ -30,7 +30,7 @@ rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
-symphonia = { version = "0.5.3", features = ["all"], optional = true }
+symphonia = { version = "0.5.3", features = ["all"] }
 tokenizers = { workspace = true, features = ["onig"] }
 cpal = { version = "0.15.2", optional = true }
 
@@ -80,26 +80,6 @@ required-features = ["onnx"]
 name = "onnx_basics"
 required-features = ["onnx"]
 
-[[example]]
-name = "whisper"
-required-features = ["symphonia"]
-
-[[example]]
-name = "whisper-microphone"
-required-features = ["microphone"]
-
-[[example]]
-name = "mnist-training"
-required-features = ["candle-datasets"]
-
-[[example]]
-name = "llama2-c"
-required-features = ["candle-datasets"]
-
-[[example]]
-name = "encodec"
-required-features = ["symphonia"]
-
-[[example]]
-name = "metavoice"
-required-features = ["symphonia"]
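Note on the `[[example]]` blocks: `required-features` ties an example binary to the cargo features it needs, so `cargo run --example whisper` previously required `--features symphonia`. With `symphonia` and `candle-datasets` made unconditional in the hunks above, that gating becomes redundant for these examples, which is presumably why the blocks are dropped together with the `optional = true` markers.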
@@ -1 +0,0 @@
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
@@ -1,20 +0,0 @@
# candle-efficientvit

[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).

This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet : 2.25%
alp : 0.46%
```
@@ -1,99 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};

use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientvit;

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    M0,
    M1,
    M2,
    M3,
    M4,
    M5,
}

impl Which {
    fn model_filename(&self) -> String {
        let name = match self {
            Self::M0 => "m0",
            Self::M1 => "m1",
            Self::M2 => "m2",
            Self::M3 => "m3",
            Self::M4 => "m4",
            Self::M5 => "m5",
        };
        format!("timm/efficientvit_{}.r224_in1k", name)
    }

    fn config(&self) -> efficientvit::Config {
        match self {
            Self::M0 => efficientvit::Config::m0(),
            Self::M1 => efficientvit::Config::m1(),
            Self::M2 => efficientvit::Config::m2(),
            Self::M3 => efficientvit::Config::m3(),
            Self::M4 => efficientvit::Config::m4(),
            Self::M5 => efficientvit::Config::m5(),
        }
    }
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    image: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    #[arg(value_enum, long, default_value_t=Which::M0)]
    which: Which,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
        None => {
            let model_name = args.which.model_filename();
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(model_name);
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
    println!("model built");
    let logits = model.forward(&image.unsqueeze(0)?)?;
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
        .i(0)?
        .to_vec1::<f32>()?;
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    for &(category_idx, pr) in prs.iter().take(5) {
        println!(
            "{:24}: {:.2}%",
            candle_examples::imagenet::CLASSES[category_idx],
            100. * pr
        );
    }
    Ok(())
}
@ -1,20 +0,0 @@
# candle-encodec

[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.

## Running one example

```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
    candle-examples/examples/encodec/jfk-codes.safetensors \
    jfk.wav
```

This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
  containing EnCodec tokens for the input audio file.
Binary file not shown.
@ -1,143 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::encodec::{Config, Model};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Action {
    AudioToAudio,
    AudioToCode,
    CodeToAudio,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// The action to be performed, specifies the format for the input and output data.
    action: Action,

    /// The input file, either an audio file or some encodec tokens stored as safetensors.
    in_file: String,

    /// The output file, either a wave audio file or some encodec tokens stored as safetensors.
    out_file: String,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weight file, in safetensor format.
    #[arg(long)]
    model: Option<String>,
}

fn main() -> Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let model = match args.model {
        Some(model) => std::path::PathBuf::from(model),
        None => Api::new()?
            .model("facebook/encodec_24khz".to_string())
            .get("model.safetensors")?,
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
    let config = Config::default();
    let model = Model::new(&config, vb)?;

    let codes = match args.action {
        Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
            let codes = codes.get("codes").expect("no codes in input file").i(0)?;
            codes
        }
        Action::AudioToCode | Action::AudioToAudio => {
            let (pcm, sample_rate) = pcm_decode(args.in_file)?;
            if sample_rate != 24_000 {
                println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
            }
            let pcm_len = pcm.len();
            let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
            println!("input pcm shape: {:?}", pcm.shape());
            model.encode(&pcm)?
        }
    };
    println!("codes shape: {:?}", codes.shape());

    match args.action {
        Action::AudioToCode => {
            codes.save_safetensors("codes", &args.out_file)?;
        }
        Action::AudioToAudio | Action::CodeToAudio => {
            let pcm = model.decode(&codes)?;
            println!("output pcm shape: {:?}", pcm.shape());
            let pcm = pcm.i(0)?.i(0)?;
            let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
            let pcm = pcm.to_vec1::<f32>()?;
            let mut output = std::fs::File::create(&args.out_file)?;
            candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
        }
    }
    Ok(())
}
@ -1,27 +0,0 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind

[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.

In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).

## Running the example

```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
    let mut primes = vec![true; max_n];
    for i in 2..=max_n {
        if primes[i] {
            for j in i * i..max_n {
                primes[j] = false;
            }
        }
    }
    primes.len()
}
```
@ -1,256 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::gemma::{Config, Model};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => match model_id.as_str() {
            "7b-it" => "google/gemma-7b-it".to_string(),
            "7b" => "google/gemma-7b".to_string(),
            "2b-it" => "google/gemma-2b-it".to_string(),
            "2b" => "google/gemma-2b".to_string(),
            _ => model_id.to_string(),
        },
        None => "google/gemma-2b".to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@ -57,7 +57,7 @@ struct Args {
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 10000)]
#[arg(long, default_value_t = 100)]
sample_len: usize,

/// Disable the key-value cache.
@ -120,7 +120,7 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, mut cache) = {
let (llama, tokenizer_filename, cache) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
@ -143,10 +143,11 @@ fn main() -> Result<()> {
}
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
};
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &config)?, tokenizer_filename, cache)
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
@ -156,7 +157,6 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

println!("starting the inference loop");
print!("{prompt}");
@ -172,7 +172,7 @@ fn main() -> Result<()> {
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = llama.forward(&input, context_index, &mut cache)?;
let logits = llama.forward(&input, context_index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
logits
@ -190,16 +190,18 @@ fn main() -> Result<()> {
token_generated += 1;
tokens.push(next_token);

// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
if Some(next_token) == eos_token_id {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
use std::io::Write;
use tokenizers::Tokenizer;

use model::{Cache, Config, Llama};
use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;

@ -160,10 +160,10 @@ enum Model {
}

impl Model {
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
}
}
}
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let config = Config::from_reader(&mut file)?;
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?;
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;

let tokens = match &args.pretokenized_dir {
None => {
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
for inp_tgt in batch_iter {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, &mut cache)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
println!("{}", loss.to_vec0::<f32>()?);
}
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config, mut cache) = if is_gguf {
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
&device,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
(model, config, cache)
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, config.clone())?);
(model, config, cache)
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
};

println!("starting the inference loop");
@ -328,7 +328,6 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);

let start_gen = std::time::Instant::now();
for index in 0.. {
@ -338,7 +337,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
let logits = model.forward(&input, index_pos, &mut cache)?;
let logits = model.forward(&input, index_pos)?;
let logits = logits.i((0, logits.dim(1)? - 1))?;
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
logits
@ -354,14 +353,16 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {

let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
// Extracting the last token as a string is complicated, here we just apply some simple
// heuristics as it seems to work well enough for this example. See the following for more
// details:
// https://github.com/huggingface/tokenizers/issues/1141#issuecomment-1562644141
if let Some(text) = tokenizer.id_to_token(next_token) {
let text = text.replace('▁', " ").replace("<0x0A>", "\n");
print!("{text}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
let dt = start_gen.elapsed();
println!(
"\n{} tokens generated ({:.2} token/s)\n",
@ -8,7 +8,6 @@ fn valid_loss(
model: &Llama,
args: &crate::TrainingCmd,
device: &Device,
cache: &mut Cache,
) -> Result<f64> {
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@ -16,7 +15,7 @@ fn valid_loss(
let mut cnt = 0usize;
for inp_tgt in batch_iter.take(50) {
let (inp, tgt) = inp_tgt?;
let logits = model.forward(&inp, 0, cache)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
sum_ce += loss.to_vec0::<f32>()? as f64;
cnt += 1;
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, config)?;
let cache = Cache::new(false, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
let params = candle_nn::ParamsAdamW {
lr: args.learning_rate,
..Default::default()
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
for (batch_index, batch) in batch_iter.enumerate() {
let (inp, tgt) = batch?;
let logits = model.forward(&inp, 0, &mut cache)?;
let logits = model.forward(&inp, 0)?;
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
opt.backward_step(&loss)?;

if batch_index > 0 && batch_index % 100 == 0 {
// TODO: Add a way to deactivate the backprop graph tracking when computing the
// validation loss.
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
let loss = valid_loss(&dataset, &model, args, &device)?;
println!("{batch_index} {loss}");
}
if batch_index > 0 && batch_index % 1000 == 0 {
@ -1,18 +0,0 @@
# candle-metavoice

MetaVoice-1B is a text-to-speech model trained on 100K hours of speech, more
details on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).

Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
  probably because of some implementation discrepancies.

## Run an example

```bash
cargo run --example metavoice --release -- \
  --prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```
@ -1,342 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use clap::Parser;
use std::io::Write;

use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{
    adapters, gpt, speaker_encoder, tokenizers, transformer,
};

use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};

pub const ENCODEC_NTOKENS: u32 = 1024;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
    F32,
    F16,
    Bf16,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The guidance scale.
    #[arg(long, default_value_t = 3.0)]
    guidance_scale: f64,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 1.0)]
    temperature: f64,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The maximum number of tokens to generate for the first stage.
    #[arg(long, default_value_t = 2000)]
    max_tokens: u64,

    /// The output file using the wav format.
    #[arg(long, default_value = "out.wav")]
    out_file: String,

    #[arg(long)]
    first_stage_meta: Option<String>,

    #[arg(long)]
    first_stage_weights: Option<String>,

    #[arg(long)]
    second_stage_weights: Option<String>,

    #[arg(long)]
    speaker_encoder_weights: Option<String>,

    #[arg(long)]
    encodec_weights: Option<String>,

    /// The speaker embeddings, either an audio file in which case they are extracted, or a
    /// safetensors file with the embeddings already extracted.
    #[arg(long)]
    spk_emb: Option<String>,

    #[arg(long, default_value = "f32")]
    dtype: ArgDType,
}

fn mel_filters() -> Result<Vec<f32>> {
    let mel_bytes = include_bytes!("melfilters40.bytes").as_slice();
    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
    Ok(mel_filters)
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    let device = candle_examples::device(args.cpu)?;
    let api = Api::new()?;
    let repo = api.model("lmz/candle-metavoice".to_string());
    let first_stage_meta = match &args.first_stage_meta {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("first_stage.meta.json")?,
    };
    let first_stage_meta: serde_json::Value =
        serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
    let first_stage_tokenizer = match first_stage_meta.as_object() {
        None => anyhow::bail!("not a json object"),
        Some(j) => match j.get("tokenizer") {
            None => anyhow::bail!("no tokenizer key"),
            Some(j) => j,
        },
    };
    let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;

    let first_stage_weights = match &args.first_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("first_stage.safetensors")?,
    };
    let second_stage_weights = match &args.second_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("second_stage.safetensors")?,
    };
    let encodec_weights = match args.encodec_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => Api::new()?
            .model("facebook/encodec_24khz".to_string())
            .get("model.safetensors")?,
    };
    let dtype = match args.dtype {
        ArgDType::F32 => DType::F32,
        ArgDType::F16 => DType::F16,
        ArgDType::Bf16 => DType::BF16,
    };
    let first_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
    let first_stage_config = transformer::Config::cfg1b_v0_1();
    let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;

    let second_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
    let second_stage_config = gpt::Config::cfg1b_v0_1();
    let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;

    let encodec_device = if device.is_metal() {
        &candle::Device::Cpu
    } else {
        &device
    };
    let encodec_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
    let encodec_config = encodec::Config::default();
    let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;

    println!("prompt: '{}'", args.prompt);
    let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
    let mut tokens = prompt_tokens.clone();
    println!("{tokens:?}");
    let safetensors_embeddings = args
        .spk_emb
        .as_ref()
        .map_or(true, |v| v.ends_with("safetensors"));
    let spk_emb = if safetensors_embeddings {
        let spk_emb_file = match &args.spk_emb {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("spk_emb.safetensors")?,
        };
        let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
        match spk_emb.get("spk_emb") {
            None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
            Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
        }
    } else {
        let weights = match &args.speaker_encoder_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("speaker_encoder.safetensors")?,
        };
        let mel_filters = mel_filters()?;
        let config = speaker_encoder::Config::cfg();
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
        let model = speaker_encoder::Model::new(config, vb)?;
        let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
        if sample_rate != 16_000 {
            eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
        }
        model.embed_utterance(
            &pcm,
            &mel_filters,
            /* rate */ 1.3,
            /* min_c */ 0.75,
            &device,
        )?
    };
    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));

    // First stage generation.
    for index in 0..args.max_tokens {
        let context_size = if index > 0 { 1 } else { tokens.len() };
        let start_pos = tokens.len().saturating_sub(context_size);
        let ctxt = &tokens[start_pos..];
        let input = Tensor::new(ctxt, &device)?;
        let input = Tensor::stack(&[&input, &input], 0)?;
        let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
        let logits0 = logits.i((0, 0))?;
        let logits1 = logits.i((1, 0))?;
        let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
        let logits = logits.to_dtype(DType::F32)?;
        let next_token = logits_processor.sample(&logits)?;
        tokens.push(next_token);
        print!(".");
        std::io::stdout().flush()?;
        if next_token == 2048 {
            break;
        }
    }
    println!();
    let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
    let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
    println!("text ids len: {}", text_ids.len());
    let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
    // TODO: Use the config rather than hardcoding the offset here.
    let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
    let mut hierarchies_in1 =
        [encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
    let mut hierarchies_in2 = [
        vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
        ids2.as_slice(),
        &[ENCODEC_NTOKENS],
    ]
    .concat();
    hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
    hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
    let in_x1 = Tensor::new(hierarchies_in1, &device)?;
    let in_x2 = Tensor::new(hierarchies_in2, &device)?;
    let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
    let logits = second_stage_model.forward(&in_x)?;
    println!("sampling from logits...");
    let mut codes = vec![];
    for logits in logits.iter() {
        let logits = logits.squeeze(0)?;
        let (seq_len, _) = logits.dims2()?;
        let mut codes_ = Vec::with_capacity(seq_len);
        for step in 0..seq_len {
            let logits = logits.i(step)?.to_dtype(DType::F32)?;
            let logits = &(&logits / 1.0)?;
            let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
            let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
            let sample = distr.sample(&mut rng) as u32;
            codes_.push(sample)
        }
        codes.push(codes_)
    }

    let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
    let codes = Tensor::cat(&[in_x, codes], 1)?;
    println!("codes: {codes}");
    let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
    let codes = codes.i(0)?.to_vec2::<u32>()?;
    let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
    println!("text_ids len: {:?}", text_ids.len());
    let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
    println!("audio_ids shape: {:?}", audio_ids.shape());
    let pcm = encodec_model.decode(&audio_ids)?;
    println!("output pcm shape: {:?}", pcm.shape());
    let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
    let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
    let pcm = pcm.to_vec1::<f32>()?;
    let mut output = std::fs::File::create(&args.out_file)?;
    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
    Ok(())
}
Binary file not shown.
@ -152,7 +152,7 @@ struct Args {
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,

#[arg(long)]
@ -143,7 +143,7 @@ struct Args {
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,

#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
580
candle-examples/examples/musicgen/encodec_model.rs
Normal file
@ -0,0 +1,580 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NormType {
|
||||
WeightNorm,
|
||||
TimeGroupNorm,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
target_bandwidths: Vec<f64>,
|
||||
sampling_rate: usize,
|
||||
audio_channels: usize,
|
||||
normalize: bool,
|
||||
chunk_length_s: Option<usize>,
|
||||
overlap: Option<usize>,
|
||||
hidden_size: usize,
|
||||
num_filters: usize,
|
||||
num_residual_layers: usize,
|
||||
upsampling_ratios: Vec<usize>,
|
||||
norm_type: NormType,
|
||||
kernel_size: usize,
|
||||
last_kernel_size: usize,
|
||||
residual_kernel_size: usize,
|
||||
dilation_growth_rate: usize,
|
||||
use_causal_conv: bool,
|
||||
pad_mode: &'static str,
|
||||
compress: usize,
|
||||
num_lstm_layers: usize,
|
||||
trim_right_ratio: f64,
|
||||
codebook_size: usize,
|
||||
codebook_dim: Option<usize>,
|
||||
use_conv_shortcut: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
|
||||
sampling_rate: 24_000,
|
||||
audio_channels: 1,
|
||||
normalize: false,
|
||||
chunk_length_s: None,
|
||||
overlap: None,
|
||||
hidden_size: 128,
|
||||
num_filters: 32,
|
||||
num_residual_layers: 1,
|
||||
upsampling_ratios: vec![8, 5, 4, 2],
|
||||
norm_type: NormType::WeightNorm,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
residual_kernel_size: 3,
|
||||
dilation_growth_rate: 2,
|
||||
use_causal_conv: true,
|
||||
pad_mode: "reflect",
|
||||
compress: 2,
|
||||
num_lstm_layers: 2,
|
||||
trim_right_ratio: 1.0,
|
||||
codebook_size: 1024,
|
||||
codebook_dim: None,
|
||||
use_conv_shortcut: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||
pub fn musicgen_small() -> Self {
|
||||
Self {
|
||||
audio_channels: 1,
|
||||
chunk_length_s: None,
|
||||
codebook_dim: Some(128),
|
||||
codebook_size: 2048,
|
||||
compress: 2,
|
||||
dilation_growth_rate: 2,
|
||||
hidden_size: 128,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
norm_type: NormType::WeightNorm,
|
||||
normalize: false,
|
||||
num_filters: 64,
|
||||
num_lstm_layers: 2,
|
||||
num_residual_layers: 1,
|
||||
overlap: None,
|
||||
pad_mode: "reflect",
|
||||
residual_kernel_size: 3,
|
||||
sampling_rate: 32_000,
|
||||
target_bandwidths: vec![2.2],
|
||||
trim_right_ratio: 1.0,
|
||||
upsampling_ratios: vec![8, 5, 4, 4],
|
||||
use_causal_conv: false,
|
||||
use_conv_shortcut: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn codebook_dim(&self) -> usize {
|
||||
self.codebook_dim.unwrap_or(self.codebook_size)
|
||||
}
|
||||
|
||||
fn frame_rate(&self) -> usize {
|
||||
let hop_length: usize = self.upsampling_ratios.iter().product();
|
||||
(self.sampling_rate + hop_length - 1) / hop_length
|
||||
}
|
||||
|
||||
fn num_quantizers(&self) -> usize {
|
||||
let num = 1000f64
|
||||
* self
|
||||
.target_bandwidths
|
||||
.last()
|
||||
.expect("empty target_bandwidths");
|
||||
(num as usize) / (self.frame_rate() * 10)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
|
||||
#[derive(Debug)]
|
||||
struct EncodecEuclideanCodebook {
|
||||
inited: Tensor,
|
||||
cluster_size: Tensor,
|
||||
embed: Tensor,
|
||||
embed_avg: Tensor,
|
||||
}
|
||||
|
||||
impl EncodecEuclideanCodebook {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inited = vb.get(1, "inited")?;
|
||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||
let embed = vb.get(e_shape, "embed")?;
|
||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||
Ok(Self {
|
||||
inited,
|
||||
cluster_size,
|
||||
embed,
|
||||
embed_avg,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.embed.embedding(embed_ind)?;
|
||||
Ok(quantize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecVectorQuantization {
|
||||
codebook: EncodecEuclideanCodebook,
|
||||
}
|
||||
|
||||
impl EncodecVectorQuantization {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
||||
Ok(Self { codebook })
|
||||
}
|
||||
|
||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.codebook.decode(embed_ind)?;
|
||||
let quantize = quantize.transpose(1, 2)?;
|
||||
Ok(quantize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecResidualVectorQuantizer {
|
||||
layers: Vec<EncodecVectorQuantization>,
|
||||
}
|
||||
|
||||
impl EncodecResidualVectorQuantizer {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("layers");
|
||||
let layers = (0..cfg.num_quantizers())
|
||||
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||
if codes.dim(0)? != self.layers.len() {
|
||||
candle::bail!(
|
||||
"codes shape {:?} does not match the number of quantization layers {}",
|
||||
codes.shape(),
|
||||
self.layers.len()
|
||||
)
|
||||
}
|
||||
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let quantized = layer.decode(&codes.i(i)?)?;
|
||||
quantized_out = quantized.broadcast_add(&quantized_out)?;
|
||||
}
|
||||
Ok(quantized_out)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||
#[derive(Debug)]
|
||||
struct EncodecLSTM {
|
||||
layers: Vec<candle_nn::LSTM>,
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for layer_idx in 0..cfg.num_lstm_layers {
|
||||
let config = candle_nn::LSTMConfig {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||
layers.push(lstm)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecLSTM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecConvTranspose1d {
|
||||
weight_g: Tensor,
|
||||
weight_v: Tensor,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl EncodecConvTranspose1d {
|
||||
fn load(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k: usize,
|
||||
_stride: usize,
|
||||
vb: VarBuilder,
|
||||
_cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let vb = &vb.pp("conv");
|
||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Self {
|
||||
weight_g,
|
||||
weight_v,
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecConv1d {
|
||||
causal: bool,
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
}
|
||||
|
||||
impl EncodecConv1d {
|
||||
fn load(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let conv = match cfg.norm_type {
|
||||
NormType::WeightNorm => conv1d_weight_norm(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
let norm = match cfg.norm_type {
|
||||
NormType::None | NormType::WeightNorm => None,
|
||||
NormType::TimeGroupNorm => {
|
||||
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(gn)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
causal: cfg.use_causal_conv,
|
||||
conv,
|
||||
norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
match &self.norm {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecResnetBlock {
|
||||
block_conv1: EncodecConv1d,
|
||||
block_conv2: EncodecConv1d,
|
||||
shortcut: Option<EncodecConv1d>,
|
||||
}
|
||||
|
||||
impl EncodecResnetBlock {
|
||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new(vb.pp("block"));
|
||||
if dilations.len() != 2 {
|
||||
candle::bail!("expected dilations of size 2")
|
||||
}
|
||||
// TODO: Apply dilations!
|
||||
layer.inc();
|
||||
let block_conv1 =
|
||||
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
||||
layer.inc();
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
||||
let shortcut = if cfg.use_conv_shortcut {
|
||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
block_conv1,
|
||||
block_conv2,
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
let xs = self.block_conv1.forward(&xs)?;
|
||||
let xs = xs.elu(1.)?;
|
||||
let xs = self.block_conv2.forward(&xs)?;
|
||||
let xs = match &self.shortcut {
|
||||
None => (xs + residual)?,
|
||||
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
|
||||
};
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
struct Layer<'a> {
|
||||
vb: VarBuilder<'a>,
|
||||
cnt: usize,
|
||||
}
|
||||
|
||||
impl<'a> Layer<'a> {
|
||||
fn new(vb: VarBuilder<'a>) -> Self {
|
||||
Self { vb, cnt: 0 }
|
||||
}
|
||||
|
||||
fn inc(&mut self) {
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
fn next(&mut self) -> VarBuilder {
|
||||
let vb = self.vb.pp(&self.cnt.to_string());
|
||||
self.cnt += 1;
|
||||
vb
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecEncoder {
|
||||
init_conv: EncodecConv1d,
|
||||
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
|
||||
final_lstm: EncodecLSTM,
|
||||
final_conv: EncodecConv1d,
|
||||
}
|
||||
|
||||
impl EncodecEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.audio_channels,
|
||||
cfg.num_filters,
|
||||
cfg.kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut sampling_layers = vec![];
|
||||
let mut scaling = 1;
|
||||
        for &ratio in cfg.upsampling_ratios.iter().rev() {
            let current_scale = scaling * cfg.num_filters;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::load(
                    current_scale,
                    &[cfg.dilation_growth_rate.pow(j), 1],
                    layer.next(),
                    cfg,
                )?;
                resnets.push(resnet)
            }
            layer.inc(); // ELU
            let conv1d = EncodecConv1d::load(
                current_scale,
                current_scale * 2,
                ratio * 2,
                ratio,
                layer.next(),
                cfg,
            )?;
            sampling_layers.push((resnets, conv1d));
            scaling *= 2;
        }
        let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::load(
            cfg.num_filters * scaling,
            cfg.hidden_size,
            cfg.last_kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        Ok(Self {
            init_conv,
            sampling_layers,
            final_conv,
            final_lstm,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?;
        for (resnets, conv) in self.sampling_layers.iter() {
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?;
            }
            xs = xs.elu(1.0)?.apply(conv)?;
        }
        xs.apply(&self.final_lstm)?
            .elu(1.0)?
            .apply(&self.final_conv)
    }
}

#[derive(Debug)]
struct EncodecDecoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}

impl EncodecDecoder {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
        let init_conv = EncodecConv1d::load(
            cfg.hidden_size,
            cfg.num_filters * scaling,
            cfg.last_kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
        let mut sampling_layers = vec![];
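        // The decoder mirrors the encoder: it starts from the widest channel
        // count and halves it at each upsampling stage.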
        for &ratio in cfg.upsampling_ratios.iter() {
            let current_scale = scaling * cfg.num_filters;
            layer.inc(); // ELU
            let conv1d = EncodecConvTranspose1d::load(
                current_scale,
                current_scale / 2,
                ratio * 2,
                ratio,
                layer.next(),
                cfg,
            )?;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::load(
                    current_scale / 2,
                    &[cfg.dilation_growth_rate.pow(j), 1],
                    layer.next(),
                    cfg,
                )?;
                resnets.push(resnet)
            }
            sampling_layers.push((conv1d, resnets));
            scaling /= 2;
        }
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::load(
            cfg.num_filters,
            cfg.audio_channels,
            cfg.last_kernel_size,
            1,
            layer.next(),
            cfg,
        )?;
        Ok(Self {
            init_conv,
            init_lstm,
            sampling_layers,
            final_conv,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
        for (conv, resnets) in self.sampling_layers.iter() {
            xs = xs.elu(1.)?.apply(conv)?;
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?
            }
        }
        xs.elu(1.)?.apply(&self.final_conv)
    }
}

#[derive(Debug)]
pub struct EncodecModel {
    encoder: EncodecEncoder,
    decoder: EncodecDecoder,
    quantizer: EncodecResidualVectorQuantizer,
}

impl EncodecModel {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
        let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
        let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
        Ok(Self {
            encoder,
            decoder,
            quantizer,
        })
    }

    pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
        todo!()
    }
}

@ -10,7 +10,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod encodec_model;
mod musicgen_model;
mod nn;

use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};

@ -1,9 +1,10 @@
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
    embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
    VarBuilder,
};
use candle_transformers::models::{encodec, t5};
use candle_transformers::models::t5;

// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
    pub text_encoder: t5::T5EncoderModel,
    pub audio_encoder: encodec::Model,
    pub audio_encoder: crate::encodec_model::EncodecModel,
    pub decoder: MusicgenForCausalLM,
    cfg: GenConfig,
}
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
pub struct GenConfig {
    musicgen: Config,
    t5: t5::Config,
    encodec: encodec::Config,
    encodec: crate::encodec_model::Config,
}

impl GenConfig {
    pub fn small() -> Self {
        // https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
        let encodec = encodec::Config {
            audio_channels: 1,
            chunk_length_s: None,
            codebook_dim: Some(128),
            codebook_size: 2048,
            compress: 2,
            dilation_growth_rate: 2,
            hidden_size: 128,
            kernel_size: 7,
            last_kernel_size: 7,
            norm_type: encodec::NormType::WeightNorm,
            normalize: false,
            num_filters: 64,
            num_lstm_layers: 2,
            num_residual_layers: 1,
            overlap: None,
            // This should be Reflect and not Replicate but Reflect does not work yet.
            pad_mode: encodec::PadMode::Replicate,
            residual_kernel_size: 3,
            sampling_rate: 32_000,
            target_bandwidths: vec![2.2],
            trim_right_ratio: 1.0,
            upsampling_ratios: vec![8, 5, 4, 4],
            use_causal_conv: false,
            use_conv_shortcut: false,
        };
        Self {
            musicgen: Config::musicgen_small(),
            t5: t5::Config::musicgen_small(),
            encodec,
            encodec: encodec_model::Config::musicgen_small(),
        }
    }
}
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {

    pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
        let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
        let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
        let audio_encoder =
            encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
        let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
        Ok(Self {
            text_encoder,

candle-examples/examples/musicgen/nn.rs (new file)
@ -0,0 +1,20 @@
use candle::Result;
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};

// Applies weight norm for inference by recomputing the weight tensor. This
// does not apply to training.
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
pub fn conv1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    config: Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
    let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
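    // Per the PyTorch docs linked above, weight_norm re-parameterizes the
    // weight as w = g * v / ||v||; the line above recovers w per output
    // channel from the stored `weight_g` and `weight_v` tensors.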
    let bias = vb.get(out_c, "bias")?;
    Ok(Conv1d::new(weight, Some(bias), config))
}

@ -212,14 +212,6 @@ struct Args {
    #[arg(long)]
    verbose_prompt: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
@ -369,7 +361,7 @@ fn main() -> anyhow::Result<()> {
    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let device = candle_examples::device(false)?;

    let mut model = match model_path.extension().and_then(|v| v.to_str()) {
        Some("gguf") => {
@ -495,20 +487,11 @@ fn main() -> anyhow::Result<()> {
    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);

    let start_prompt_processing = std::time::Instant::now();
    let mut next_token = if !args.split_prompt {
    let mut next_token = {
        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
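        // Feeding the prompt one token at a time is slower but keeps the peak
        // memory of prompt processing low, which is presumably the motivation
        // for the --split-prompt flag.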
        let mut next_token = 0;
        for (pos, token) in prompt_tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);

@ -2,8 +2,8 @@

The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 and v6 versions and can be used with
Eagle 7B ([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
available, candle implements the v5 version and can be used with Eagle 7B ([blog
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).

```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "

@ -7,36 +7,13 @@ extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};

use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
use candle_transformers::models::rwkv_v6::Model as M6;
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};

const EOS_TOKEN_ID: u32 = 261;

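// Dispatch enum so the generation loop can drive any of the four variants
// (v5 or v6, full precision or quantized) through a single `forward`.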
enum Model {
    M5(M5),
    Q5(Q5),
    M6(M6),
    Q6(Q6),
}

impl Model {
    fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
        match self {
            Self::M5(m) => m.forward(xs, state),
            Self::Q5(m) => m.forward(xs, state),
            Self::M6(m) => m.forward(xs, state),
            Self::Q6(m) => m.forward(xs, state),
        }
    }
}

struct TextGeneration {
    model: Model,
    config: Config,
@ -106,9 +83,6 @@ impl TextGeneration {
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == EOS_TOKEN_ID || next_token == 0 {
                break;
            }
            print!("{}", self.tokenizer.decode(&[next_token])?);
            std::io::stdout().flush()?;

@ -129,7 +103,6 @@ enum Which {
    Eagle7b,
    World1b5,
    World3b,
    World6_1b6,
}

impl std::fmt::Display for Which {
@ -144,7 +117,6 @@ impl Which {
            Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
            Self::World3b => "RWKV/rwkv-5-world-3b",
            Self::World6_1b6 => "paperfun/rwkv",
        }
    }

@ -152,7 +124,6 @@ impl Which {
        match self {
            Self::Eagle7b => "refs/pr/1",
            Self::World1b5 | Self::World3b => "refs/pr/2",
            Self::World6_1b6 => "main",
        }
    }
}
@ -205,9 +176,6 @@ struct Args {
    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
@ -268,27 +236,7 @@ fn main() -> Result<()> {
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                vec![match args.which {
                    Which::World1b5 => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("world1b5-q4k.gguf")?,
                    Which::World3b => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("world3b-q4k.gguf")?,
                    Which::Eagle7b => api
                        .model("lmz/candle-rwkv".to_string())
                        .get("eagle7b-q4k.gguf")?,
                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
                }]
            } else {
                vec![match args.which {
                    Which::World1b5 | Which::World3b | Which::Eagle7b => {
                        repo.get("model.safetensors")?
                    }
                    Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
                }]
            }
            vec![repo.get("model.safetensors")?]
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());

@ -297,21 +245,8 @@ fn main() -> Result<()> {
    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let model = if args.quantized {
        let filename = &filenames[0];
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        match args.which {
            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
            Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
        }
    } else {
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
        match args.which {
            Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
            Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
        }
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(

@ -1,28 +0,0 @@
# candle-segformer

- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder-only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine-tuned model for segmentation][ade512]

## How to run the example

If you want, you can use the example images from this [pull request][pr]: download them and supply the path to an image as the argument to the example.

```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```

Example output for classification:

```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```

[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

@ -1,752 +0,0 @@
[
  { "index": 1, "color": "#787878", "label": "wall" },
  { "index": 2, "color": "#B47878", "label": "building;edifice" },
  { "index": 3, "color": "#06E6E6", "label": "sky" },
  { "index": 4, "color": "#503232", "label": "floor;flooring" },
  { "index": 5, "color": "#04C803", "label": "tree" },
  { "index": 6, "color": "#787850", "label": "ceiling" },
  { "index": 7, "color": "#8C8C8C", "label": "road;route" },
  { "index": 8, "color": "#CC05FF", "label": "bed" },
  { "index": 9, "color": "#E6E6E6", "label": "windowpane;window" },
  { "index": 10, "color": "#04FA07", "label": "grass" },
  { "index": 11, "color": "#E005FF", "label": "cabinet" },
  { "index": 12, "color": "#EBFF07", "label": "sidewalk;pavement" },
  { "index": 13, "color": "#96053D", "label": "person;individual;someone;somebody;mortal;soul" },
  { "index": 14, "color": "#787846", "label": "earth;ground" },
  { "index": 15, "color": "#08FF33", "label": "door;double;door" },
  { "index": 16, "color": "#FF0652", "label": "table" },
  { "index": 17, "color": "#8FFF8C", "label": "mountain;mount" },
  { "index": 18, "color": "#CCFF04", "label": "plant;flora;plant;life" },
  { "index": 19, "color": "#FF3307", "label": "curtain;drape;drapery;mantle;pall" },
  { "index": 20, "color": "#CC4603", "label": "chair" },
  { "index": 21, "color": "#0066C8", "label": "car;auto;automobile;machine;motorcar" },
  { "index": 22, "color": "#3DE6FA", "label": "water" },
  { "index": 23, "color": "#FF0633", "label": "painting;picture" },
  { "index": 24, "color": "#0B66FF", "label": "sofa;couch;lounge" },
  { "index": 25, "color": "#FF0747", "label": "shelf" },
  { "index": 26, "color": "#FF09E0", "label": "house" },
  { "index": 27, "color": "#0907E6", "label": "sea" },
  { "index": 28, "color": "#DCDCDC", "label": "mirror" },
  { "index": 29, "color": "#FF095C", "label": "rug;carpet;carpeting" },
  { "index": 30, "color": "#7009FF", "label": "field" },
  { "index": 31, "color": "#08FFD6", "label": "armchair" },
  { "index": 32, "color": "#07FFE0", "label": "seat" },
  { "index": 33, "color": "#FFB806", "label": "fence;fencing" },
  { "index": 34, "color": "#0AFF47", "label": "desk" },
  { "index": 35, "color": "#FF290A", "label": "rock;stone" },
  { "index": 36, "color": "#07FFFF", "label": "wardrobe;closet;press" },
  { "index": 37, "color": "#E0FF08", "label": "lamp" },
  { "index": 38, "color": "#6608FF", "label": "bathtub;bathing;tub;bath;tub" },
  { "index": 39, "color": "#FF3D06", "label": "railing;rail" },
  { "index": 40, "color": "#FFC207", "label": "cushion" },
  { "index": 41, "color": "#FF7A08", "label": "base;pedestal;stand" },
  { "index": 42, "color": "#00FF14", "label": "box" },
  { "index": 43, "color": "#FF0829", "label": "column;pillar" },
  { "index": 44, "color": "#FF0599", "label": "signboard;sign" },
  { "index": 45, "color": "#0633FF", "label": "chest;of;drawers;chest;bureau;dresser" },
  { "index": 46, "color": "#EB0CFF", "label": "counter" },
  { "index": 47, "color": "#A09614", "label": "sand" },
  { "index": 48, "color": "#00A3FF", "label": "sink" },
  { "index": 49, "color": "#8C8C8C", "label": "skyscraper" },
  { "index": 50, "color": "#FA0A0F", "label": "fireplace;hearth;open;fireplace" },
  { "index": 51, "color": "#14FF00", "label": "refrigerator;icebox" },
  { "index": 52, "color": "#1FFF00", "label": "grandstand;covered;stand" },
  { "index": 53, "color": "#FF1F00", "label": "path" },
  { "index": 54, "color": "#FFE000", "label": "stairs;steps" },
  { "index": 55, "color": "#99FF00", "label": "runway" },
  { "index": 56, "color": "#0000FF", "label": "case;display;case;showcase;vitrine" },
  { "index": 57, "color": "#FF4700", "label": "pool;table;billiard;table;snooker;table" },
  { "index": 58, "color": "#00EBFF", "label": "pillow" },
  { "index": 59, "color": "#00ADFF", "label": "screen;door;screen" },
  { "index": 60, "color": "#1F00FF", "label": "stairway;staircase" },
  { "index": 61, "color": "#0BC8C8", "label": "river" },
  { "index": 62, "color": "#FF5200", "label": "bridge;span" },
  { "index": 63, "color": "#00FFF5", "label": "bookcase" },
  { "index": 64, "color": "#003DFF", "label": "blind;screen" },
  { "index": 65, "color": "#00FF70", "label": "coffee;table;cocktail;table" },
  { "index": 66, "color": "#00FF85", "label": "toilet;can;commode;crapper;pot;potty;stool;throne" },
  { "index": 67, "color": "#FF0000", "label": "flower" },
  { "index": 68, "color": "#FFA300", "label": "book" },
  { "index": 69, "color": "#FF6600", "label": "hill" },
  { "index": 70, "color": "#C2FF00", "label": "bench" },
  { "index": 71, "color": "#008FFF", "label": "countertop" },
  { "index": 72, "color": "#33FF00", "label": "stove;kitchen;stove;range;kitchen;range;cooking;stove" },
  { "index": 73, "color": "#0052FF", "label": "palm;palm;tree" },
  { "index": 74, "color": "#00FF29", "label": "kitchen;island" },
  { "index": 75, "color": "#00FFAD", "label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system" },
  { "index": 76, "color": "#0A00FF", "label": "swivel;chair" },
  { "index": 77, "color": "#ADFF00", "label": "boat" },
  { "index": 78, "color": "#00FF99", "label": "bar" },
  { "index": 79, "color": "#FF5C00", "label": "arcade;machine" },
  { "index": 80, "color": "#FF00FF", "label": "hovel;hut;hutch;shack;shanty" },
  { "index": 81, "color": "#FF00F5", "label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle" },
  { "index": 82, "color": "#FF0066", "label": "towel" },
  { "index": 83, "color": "#FFAD00", "label": "light;light;source" },
  { "index": 84, "color": "#FF0014", "label": "truck;motortruck" },
  { "index": 85, "color": "#FFB8B8", "label": "tower" },
  { "index": 86, "color": "#001FFF", "label": "chandelier;pendant;pendent" },
  { "index": 87, "color": "#00FF3D", "label": "awning;sunshade;sunblind" },
  { "index": 88, "color": "#0047FF", "label": "streetlight;street;lamp" },
  { "index": 89, "color": "#FF00CC", "label": "booth;cubicle;stall;kiosk" },
  { "index": 90, "color": "#00FFC2", "label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box" },
  { "index": 91, "color": "#00FF52", "label": "airplane;aeroplane;plane" },
  { "index": 92, "color": "#000AFF", "label": "dirt;track" },
  { "index": 93, "color": "#0070FF", "label": "apparel;wearing;apparel;dress;clothes" },
  { "index": 94, "color": "#3300FF", "label": "pole" },
  { "index": 95, "color": "#00C2FF", "label": "land;ground;soil" },
  { "index": 96, "color": "#007AFF", "label": "bannister;banister;balustrade;balusters;handrail" },
  { "index": 97, "color": "#00FFA3", "label": "escalator;moving;staircase;moving;stairway" },
  { "index": 98, "color": "#FF9900", "label": "ottoman;pouf;pouffe;puff;hassock" },
  { "index": 99, "color": "#00FF0A", "label": "bottle" },
  { "index": 100, "color": "#FF7000", "label": "buffet;counter;sideboard" },
  { "index": 101, "color": "#8FFF00", "label": "poster;posting;placard;notice;bill;card" },
  { "index": 102, "color": "#5200FF", "label": "stage" },
  { "index": 103, "color": "#A3FF00", "label": "van" },
  { "index": 104, "color": "#FFEB00", "label": "ship" },
  { "index": 105, "color": "#08B8AA", "label": "fountain" },
  { "index": 106, "color": "#8500FF", "label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter" },
  { "index": 107, "color": "#00FF5C", "label": "canopy" },
  { "index": 108, "color": "#B800FF", "label": "washer;automatic;washer;washing;machine" },
  { "index": 109, "color": "#FF001F", "label": "plaything;toy" },
  { "index": 110, "color": "#00B8FF", "label": "swimming;pool;swimming;bath;natatorium" },
  { "index": 111, "color": "#00D6FF", "label": "stool" },
  { "index": 112, "color": "#FF0070", "label": "barrel;cask" },
  { "index": 113, "color": "#5CFF00", "label": "basket;handbasket" },
  { "index": 114, "color": "#00E0FF", "label": "waterfall;falls" },
  { "index": 115, "color": "#70E0FF", "label": "tent;collapsible;shelter" },
  { "index": 116, "color": "#46B8A0", "label": "bag" },
  { "index": 117, "color": "#A300FF", "label": "minibike;motorbike" },
  { "index": 118, "color": "#9900FF", "label": "cradle" },
  { "index": 119, "color": "#47FF00", "label": "oven" },
  { "index": 120, "color": "#FF00A3", "label": "ball" },
  { "index": 121, "color": "#FFCC00", "label": "food;solid;food" },
  { "index": 122, "color": "#FF008F", "label": "step;stair" },
  { "index": 123, "color": "#00FFEB", "label": "tank;storage;tank" },
  { "index": 124, "color": "#85FF00", "label": "trade;name;brand;name;brand;marque" },
  { "index": 125, "color": "#FF00EB", "label": "microwave;microwave;oven" },
  { "index": 126, "color": "#F500FF", "label": "pot;flowerpot" },
  { "index": 127, "color": "#FF007A", "label": "animal;animate;being;beast;brute;creature;fauna" },
  { "index": 128, "color": "#FFF500", "label": "bicycle;bike;wheel;cycle" },
  { "index": 129, "color": "#0ABED4", "label": "lake" },
  { "index": 130, "color": "#D6FF00", "label": "dishwasher;dish;washer;dishwashing;machine" },
  { "index": 131, "color": "#00CCFF", "label": "screen;silver;screen;projection;screen" },
  { "index": 132, "color": "#1400FF", "label": "blanket;cover" },
  { "index": 133, "color": "#FFFF00", "label": "sculpture" },
  { "index": 134, "color": "#0099FF", "label": "hood;exhaust;hood" },
  { "index": 135, "color": "#0029FF", "label": "sconce" },
  { "index": 136, "color": "#00FFCC", "label": "vase" },
  { "index": 137, "color": "#2900FF", "label": "traffic;light;traffic;signal;stoplight" },
  { "index": 138, "color": "#29FF00", "label": "tray" },
  { "index": 139, "color": "#AD00FF", "label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin" },
  { "index": 140, "color": "#00F5FF", "label": "fan" },
  { "index": 141, "color": "#4700FF", "label": "pier;wharf;wharfage;dock" },
  { "index": 142, "color": "#7A00FF", "label": "crt;screen" },
  { "index": 143, "color": "#00FFB8", "label": "plate" },
  { "index": 144, "color": "#005CFF", "label": "monitor;monitoring;device" },
  { "index": 145, "color": "#B8FF00", "label": "bulletin;board;notice;board" },
  { "index": 146, "color": "#0085FF", "label": "shower" },
  { "index": 147, "color": "#FFD600", "label": "radiator" },
  { "index": 148, "color": "#19C2C2", "label": "glass;drinking;glass" },
  { "index": 149, "color": "#66FF00", "label": "clock" },
  { "index": 150, "color": "#5C00FF", "label": "flag" }
]

@ -1,155 +0,0 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
    Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use image::Rgb;
use imageproc::integral_image::ArrayData;
use std::collections::HashMap;
use std::path::PathBuf;

#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
    #[arg(long, help = "use cpu")]
    cpu: bool,
    #[command(subcommand)]
    command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
    )]
    model_name: String,
    #[arg(
        long,
        help = "path to the label file in json format",
        default_value = "candle-examples/examples/segformer/assets/labels.json"
    )]
    label_path: PathBuf,
    #[arg(long, help = "path for the output mask image")]
    output_path: PathBuf,
    #[arg(help = "path to image as input")]
    image: PathBuf,
}

#[derive(Args, Debug)]
struct ClassificationArgs {
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "paolinox/segformer-finetuned-food101"
    )]
    model_name: String,
    #[arg(help = "path to image as input")]
    image: PathBuf,
}

#[derive(Subcommand, Debug)]
enum Commands {
    Segment(SegmentationArgs),
    Classify(ClassificationArgs),
}

fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
    println!("loading model {} via huggingface hub", model_name);
    let api = hf_hub::api::sync::Api::new()?;
    let api = api.model(model_name.clone());
    let model_file = api.get("model.safetensors")?;
    println!("model {} downloaded and loaded", model_name);
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
    let config = std::fs::read_to_string(api.get("config.json")?)?;
    let config: Config = serde_json::from_str(&config)?;
    println!("{:?}", config);
    Ok((vb, config))
}

#[derive(Debug, serde::Deserialize)]
struct LabelItem {
    index: u32,
    color: String,
}

fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
    let label_file = std::fs::read_to_string(&args.label_path)?;
    let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
    let label_colors: HashMap<u32, Rgb<u8>> = label_items
        .iter()
        .map(|x| {
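            // Labels in the json file are 1-indexed while the model's class
            // ids start at 0, hence the `- 1`.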
            (x.index - 1, {
                let color = x.color.trim_start_matches('#');
                let r = u8::from_str_radix(&color[0..2], 16).unwrap();
                let g = u8::from_str_radix(&color[2..4], 16).unwrap();
                let b = u8::from_str_radix(&color[4..6], 16).unwrap();
                Rgb([r, g, b])
            })
        })
        .collect();

    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
    let num_labels = label_items.len();

    let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
    let segmentations = model.forward(&image)?;

    // generate a mask image
    let mask = &segmentations.squeeze(0)?.argmax(0)?;
    let (h, w) = mask.dims2()?;
    let mask = mask.flatten_all()?.to_vec1::<u32>()?;
    let mask = mask
        .iter()
        .flat_map(|x| label_colors[x].data())
        .collect::<Vec<u8>>();
    let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
        image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
    // resize
    let mask = image::DynamicImage::from(mask);
    let mask = mask.resize_to_fill(
        w as u32 * 4,
        h as u32 * 4,
        image::imageops::FilterType::CatmullRom,
    );
    mask.save(args.output_path.clone())?;
    println!("mask image saved to {:?}", args.output_path);
    Ok(())
}

fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
    let num_labels = 7;
    let model = ImageClassificationModel::new(&config, num_labels, vb)?;
    let classification = model.forward(&image)?;
    let classification = candle_nn::ops::softmax_last_dim(&classification)?;
    let classification = classification.squeeze(0)?;
    println!(
        "classification logits {:?}",
        classification.to_vec1::<f32>()?
    );
    let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
    let label_id = format!("{}", label_id);
    println!("label: {}", config.id2label[&label_id]);
    Ok(())
}

pub fn main() -> anyhow::Result<()> {
    let args = CliArgs::parse();
    let device = candle_examples::device(args.cpu)?;
    if let Commands::Segment(args) = args.command {
        segmentation_task(args, &device)?
    } else if let Commands::Classify(args) = args.command {
        classification_task(args, &device)?
    }
    Ok(())
}

@ -57,7 +57,7 @@ The downside is some long compilation time. You can set the
`/home/user/.candle` to ensure that the compilation artifacts are properly
cached.

Enabling flash-attention requires both a feature flag, `--features flash-attn`
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`.

Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs

@ -1,253 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::starcoder2::Model;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
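            // After the first iteration only the newest token is fed to the
            // model; earlier positions are served from the KV cache.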
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
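                // Only the last `repeat_last_n` tokens of the context are
                // penalized, not the whole generation history.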
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => "bigcode/starcoder2-3b".to_string(),
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let config_file = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let tokenizer_file = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => vec![repo.get("model.safetensors")?],
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb)?;

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}

@ -1,29 +0,0 @@
use candle::{Result, Tensor};

// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
pub fn normalize_loudness(
    wav: &Tensor,
    sample_rate: u32,
    loudness_compressor: bool,
) -> Result<Tensor> {
    let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
    if energy < 2e-3 {
        return Ok(wav.clone());
    }
    let wav_array = wav.to_vec1::<f32>()?;
    let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
    meter.push(wav_array.into_iter());
    let power = meter.as_100ms_windows();
    let loudness = match crate::bs1770::gated_mean(power) {
        None => return Ok(wav.clone()),
        Some(gp) => gp.loudness_lkfs() as f64,
    };
    let delta_loudness = -14. - loudness;
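    // The target is -14 LUFS (a common streaming reference level); the dB
    // delta is converted to a linear gain via 10^(dB / 20).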
    let gain = 10f64.powf(delta_loudness / 20.);
    let wav = (wav * gain)?;
    if loudness_compressor {
        wav.tanh()
    } else {
        Ok(wav)
    }
}

@ -1,506 +0,0 @@
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
// Copyright 2020 Ruud van Asseldonk

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// A copy of the License has been included in the root of the repository.

//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
//!
//! This library offers the building blocks to perform BS.1770 loudness
//! measurements, but you need to put the pieces together yourself.
//!
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
//!
//! # Stereo integrated loudness example
//!
//! ```ignore
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
//! #     [vec![0; 48_000], vec![0; 48_000]]
//! # }
//! #
//! let sample_rate_hz = 44_100;
//! let bits_per_sample = 16;
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
//!
//! // When converting integer samples to float, note that the maximum amplitude
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
//!
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
//!     let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
//!     meter.push(samples.iter().map(|&s| s as f32 * normalizer));
//!     meter.into_100ms_windows()
//! }).collect();
//!
//! let stereo_power = bs1770::reduce_stereo(
//!     channel_power[0].as_ref(),
//!     channel_power[1].as_ref(),
//! );
//!
//! let gated_power = bs1770::gated_mean(
//!     stereo_power.as_ref()
//! ).unwrap_or(bs1770::Power(0.0));
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
//! ```

use std::f32;

/// Coefficients for a 2nd-degree infinite impulse response filter.
///
/// Coefficient a0 is implicitly 1.0.
#[derive(Clone)]
struct Filter {
    a1: f32,
    a2: f32,
    b0: f32,
    b1: f32,
    b2: f32,

    // The past two input and output samples.
    x1: f32,
    x2: f32,
    y1: f32,
    y2: f32,
}

impl Filter {
    /// Stage 1 of the BS.1770-4 pre-filter.
    pub fn high_shelf(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let gain_db = 3.999_843_8;
        let q = 0.707_175_25;
        let center_hz = 1_681.974_5;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        let vh = 10.0_f32.powf(gain_db / 20.0);
        let vb = vh.powf(0.499_666_78);
        let a0 = 1.0 + k / q + k * k;
        Filter {
            b0: (vh + vb * k / q + k * k) / a0,
            b1: 2.0 * (k * k - vh) / a0,
            b2: (vh - vb * k / q + k * k) / a0,
            a1: 2.0 * (k * k - 1.0) / a0,
            a2: (1.0 - k / q + k * k) / a0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Stage 2 of the BS.1770-4 pre-filter.
    pub fn high_pass(sample_rate_hz: f32) -> Filter {
        // Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
        let q = 0.500_327_05;
        let center_hz = 38.135_47;

        // Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
        // 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
        let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
        Filter {
            a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
            a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
            b0: 1.0,
            b1: -2.0,
            b2: 1.0,

            x1: 0.0,
            x2: 0.0,
            y1: 0.0,
            y2: 0.0,
        }
    }

    /// Feed the next input sample, get the next output sample.
    #[inline(always)]
    pub fn apply(&mut self, x0: f32) -> f32 {
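        // A direct-form I biquad: the output is a weighted sum of the current
        // and two previous inputs, minus feedback from the two previous
        // outputs.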
        let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
            - self.a1 * self.y1
            - self.a2 * self.y2;

        self.x2 = self.x1;
        self.x1 = x0;
        self.y2 = self.y1;
        self.y1 = y0;

        y0
    }
}

/// Compensated sum, for summing many values of different orders of magnitude
/// accurately.
#[derive(Copy, Clone, PartialEq)]
struct Sum {
    sum: f32,
    residue: f32,
}

impl Sum {
    #[inline(always)]
    fn zero() -> Sum {
        Sum {
            sum: 0.0,
            residue: 0.0,
        }
    }

    #[inline(always)]
    fn add(&mut self, x: f32) {
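        // Kahan-style compensated summation: `residue` tracks the low-order
        // bits that would otherwise be lost when adding a small value to a
        // much larger running sum.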
let sum = self.sum + (self.residue + x);
|
||||
self.residue = (self.residue + x) - (sum - self.sum);
|
||||
self.sum = sum;
|
||||
}
|
||||
}
|
||||
|
||||
/// The mean of the squares of the K-weighted samples in a window of time.
|
||||
///
|
||||
/// K-weighted power is equivalent to K-weighted loudness, the only difference
|
||||
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
|
||||
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
|
||||
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
|
||||
///
|
||||
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
|
||||
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
|
||||
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
|
||||
/// relative to Full Scale). Loudness units are related to decibels in the
|
||||
/// following sense: boosting a signal that has a loudness of
|
||||
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
|
||||
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
|
||||
/// bring the loudness to 0 LUFS.
|
||||
///
|
||||
/// K-weighting refers to a high-shelf and high-pass filter that model the
|
||||
/// effect that humans perceive a certain amount of power in low frequencies to
|
||||
/// be less loud than the same amount of power in higher frequencies. In this
|
||||
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
|
||||
///
|
||||
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
|
||||
/// mean square of the samples, if no input samples exceeded the full scale, the
|
||||
/// power will be in the range [0.0, 1.0]. However, the power delivered by
|
||||
/// multiple channels, which is a weighted sum over individual channel powers,
|
||||
/// can exceed this range, because the weighted sum is not normalized.
|
||||
#[derive(Copy, Clone, PartialEq, PartialOrd)]
|
||||
pub struct Power(pub f32);
|
||||
|
||||
impl Power {
|
||||
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
|
||||
///
|
||||
/// This is the inverse of `loudness_lkfs`.
|
||||
pub fn from_lkfs(lkfs: f32) -> Power {
|
||||
// The inverse of the formula below.
|
||||
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
|
||||
}
|
||||
|
||||
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
|
||||
///
|
||||
/// This is the inverse of `from_lkfs`.
|
||||
pub fn loudness_lkfs(&self) -> f32 {
|
||||
// Equation 2 (p.5) of BS.1770-4.
|
||||
-0.691 + 10.0 * self.0.log10()
|
||||
}
|
||||
}
|
||||
|
||||
/// A `T` value for non-overlapping windows of audio, 100ms in length.
|
||||
///
|
||||
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
|
||||
/// for non-overlapping windows of 100ms duration.
|
||||
///
|
||||
/// These non-overlapping 100ms windows can later be combined into overlapping
|
||||
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
|
||||
/// to perform a gated measurement, or they can be combined into even larger
|
||||
/// windows for a momentary loudness measurement.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Windows100ms<T> {
|
||||
pub inner: T,
|
||||
}
|
||||
|
||||
impl<T> Windows100ms<T> {
|
||||
/// Wrap a new empty vector.
|
||||
pub fn new() -> Windows100ms<Vec<T>> {
|
||||
Windows100ms { inner: Vec::new() }
|
||||
}
|
||||
|
||||
/// Apply `as_ref` to the inner value.
|
||||
pub fn as_ref(&self) -> Windows100ms<&[Power]>
|
||||
where
|
||||
T: AsRef<[Power]>,
|
||||
{
|
||||
Windows100ms {
|
||||
inner: self.inner.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply `as_mut` to the inner value.
|
||||
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
|
||||
where
|
||||
T: AsMut<[Power]>,
|
||||
{
|
||||
Windows100ms {
|
||||
inner: self.inner.as_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
/// Apply `len` to the inner value.
|
||||
pub fn len(&self) -> usize
|
||||
where
|
||||
T: AsRef<[Power]>,
|
||||
{
|
||||
self.inner.as_ref().len()
|
||||
}
|
||||
}

/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
///
/// # Output
///
/// The output of the meter is an intermediate result in the form of power for
/// 100ms non-overlapping windows. The windows need to be processed further to
/// get one of the instantaneous, momentary, and integrated loudness
/// measurements defined in BS.1770.
///
/// The windows can also be inspected directly; the data is meaningful
/// on its own (the K-weighted power delivered in that window of time), but it
/// is not something that BS.1770 defines a term for.
///
/// # Multichannel audio
///
/// To perform a loudness measurement of multichannel audio, construct a
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
/// with e.g. `reduce_stereo`.
///
/// # Instantaneous loudness
///
/// The instantaneous loudness is the power over a 400ms window, so you can
/// average four 100ms windows. No special functionality is implemented to help
/// with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Momentary loudness
///
/// The momentary loudness is the power over a 3-second window, so you can
/// average thirty 100ms windows. No special functionality is implemented to
/// help with that at this time. ([Pull requests would be accepted.][contribute])
///
/// # Integrated loudness
///
/// Use `gated_mean` to perform an integrated loudness measurement:
///
/// ```ignore
/// # use std::iter;
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
/// # let sample_rate_hz = 44_100;
/// # let samples_per_100ms = sample_rate_hz / 10;
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
///     .unwrap_or(bs1770::Power(0.0))
///     .loudness_lkfs();
/// ```
///
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
#[derive(Clone)]
pub struct ChannelLoudnessMeter {
    /// The number of samples that fit in 100ms of audio.
    samples_per_100ms: u32,

    /// Stage 1 filter (head effects, high shelf).
    filter_stage1: Filter,

    /// Stage 2 filter (high-pass).
    filter_stage2: Filter,

    /// Sum of the squares over non-overlapping windows of 100ms.
    windows: Windows100ms<Vec<Power>>,

    /// The number of samples in the current unfinished window.
    count: u32,

    /// The sum of the squares of the samples in the current unfinished window.
    square_sum: Sum,
}

impl ChannelLoudnessMeter {
    /// Construct a new loudness meter for the given sample rate.
    pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
        ChannelLoudnessMeter {
            samples_per_100ms: sample_rate_hz / 10,
            filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
            filter_stage2: Filter::high_pass(sample_rate_hz as f32),
            windows: Windows100ms::new(),
            count: 0,
            square_sum: Sum::zero(),
        }
    }

    /// Feed input samples for loudness analysis.
    ///
    /// # Full scale
    ///
    /// Full scale for the input samples is the interval [-1.0, 1.0]. If your
    /// input consists of signed integer samples, you can convert as follows:
    ///
    /// ```ignore
    /// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
    /// # let bits_per_sample = 16_usize;
    /// # let samples = &[0_i16];
    /// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
    /// // one bit is the sign bit.
    /// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
    /// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
    /// ```
    ///
    /// # Repeated calls
    ///
    /// You can call `push` multiple times to feed multiple batches of samples.
    /// This is equivalent to feeding a single chained iterator. The leftover of
    /// samples that did not fill a full 100ms window is not discarded:
    ///
    /// ```ignore
    /// # use std::iter;
    /// # use bs1770::ChannelLoudnessMeter;
    /// let sample_rate_hz = 44_100;
    /// let samples_per_100ms = sample_rate_hz / 10;
    /// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
    ///
    /// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
    /// assert_eq!(meter.as_100ms_windows().len(), 0);
    ///
    /// meter.push(iter::once(0.0));
    /// assert_eq!(meter.as_100ms_windows().len(), 1);
    /// ```
    pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
        let normalizer = 1.0 / self.samples_per_100ms as f32;

        // LLVM, if you could go ahead and inline those apply calls, and then
        // unroll and vectorize the loop, that'd be terrific.
        for x in samples {
            let y = self.filter_stage1.apply(x);
            let z = self.filter_stage2.apply(y);

            self.square_sum.add(z * z);
            self.count += 1;

            // TODO: Should this branch be marked cold?
            if self.count == self.samples_per_100ms {
                let mean_squares = Power(self.square_sum.sum * normalizer);
                self.windows.inner.push(mean_squares);
                // We intentionally do not reset the residue. That way, leftover
                // energy from this window is not lost, so for the file overall,
                // the sum remains more accurate.
                self.square_sum.sum = 0.0;
                self.count = 0;
            }
        }
    }

    /// Return a reference to the 100ms windows analyzed so far.
    pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
        self.windows.as_ref()
    }

    /// Return all 100ms windows analyzed so far.
    pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
        self.windows
    }
}
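
Since the meter only emits non-overlapping 100ms windows, the overlapping 400ms (instantaneous) and 3-second (momentary) measurements described above have to be assembled by hand. A minimal sketch of that averaging, assuming plain arithmetic means of the window powers are wanted (`overlapping_means` is a hypothetical helper, not part of the crate):

```rust
// Average every run of `n` consecutive 100ms windows; n = 4 gives 400ms
// windows spaced 100ms apart, n = 30 gives 3-second windows.
fn overlapping_means(windows: &[Power], n: usize) -> Vec<Power> {
    windows
        .windows(n)
        .map(|w| Power(w.iter().map(|p| p.0).sum::<f32>() / n as f32))
        .collect()
}
```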

/// Combine power for multiple channels by taking a weighted sum.
///
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
/// sum over channels which is not normalized. This means that a stereo signal
/// is inherently louder than a mono signal. For a mono signal played back on
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
/// in the same signal for both channels.
pub fn reduce_stereo(
    left: Windows100ms<&[Power]>,
    right: Windows100ms<&[Power]>,
) -> Windows100ms<Vec<Power>> {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    let mut result = Vec::with_capacity(left.len());
    for (l, r) in left.inner.iter().zip(right.inner) {
        result.push(Power(l.0 + r.0));
    }
    Windows100ms { inner: result }
}

/// In-place version of `reduce_stereo` that stores the result in the former left channel.
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
    assert_eq!(
        left.len(),
        right.len(),
        "Channels must have the same length."
    );
    for (l, r) in left.inner.iter_mut().zip(right.inner) {
        l.0 += r.0;
    }
}

/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
///
/// The integrated loudness measurement is not just the average power over the
/// entire signal. BS.1770-4 defines two stages of gating that exclude
/// parts of the signal, to ensure that silent parts do not contribute to the
/// loudness measurement. This function performs that gating, and returns the
/// average power over the windows that were not excluded.
///
/// The result of this function is the integrated loudness measurement.
///
/// When no signal remains after applying the gate, this function returns
/// `None`. In particular, this happens when all of the signal is softer than
/// -70 LKFS, including a signal that consists of pure silence.
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
    let mut gating_blocks = Vec::with_capacity(windows_100ms.len());

    // Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
    let absolute_threshold = Power::from_lkfs(-70.0);

    // Iterate over all 400ms windows.
    for window in windows_100ms.inner.windows(4) {
        // Note that the sum over channels has already been performed at this point.
        let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());

        if gating_block_power > absolute_threshold {
            gating_blocks.push(gating_block_power);
        }
    }

    if gating_blocks.is_empty() {
        return None;
    }

    // Compute the loudness after applying the absolute gate, in order to
    // determine the threshold for the relative gate.
    let mut sum_power = Sum::zero();
    for &gating_block_power in &gating_blocks {
        sum_power.add(gating_block_power.0);
    }
    let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));

    // Stage 2: Apply the relative gate.
    let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
    let mut sum_power = Sum::zero();
    let mut n_blocks = 0_usize;
    for &gating_block_power in &gating_blocks {
        if gating_block_power > relative_threshold {
            sum_power.add(gating_block_power.0);
            n_blocks += 1;
        }
    }

    if n_blocks == 0 {
        return None;
    }

    let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
    Some(relative_gated_power)
}
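
Putting the pieces together, a full stereo integrated-loudness measurement with this API looks roughly as follows (a sketch; `left_samples` and `right_samples` are assumed to be `Vec<f32>` in the [-1.0, 1.0] range):

```rust
let sample_rate_hz = 48_000;
let mut left = ChannelLoudnessMeter::new(sample_rate_hz);
let mut right = ChannelLoudnessMeter::new(sample_rate_hz);
left.push(left_samples.iter().copied());
right.push(right_samples.iter().copied());

// Sum the per-channel powers, then gate and average per BS.1770-4.
let stereo = reduce_stereo(left.as_100ms_windows(), right.as_100ms_windows());
match gated_mean(stereo.as_ref()) {
    Some(power) => println!("integrated loudness: {:.1} LKFS", power.loudness_lkfs()),
    None => println!("no signal above the -70 LKFS gate"),
}
```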

@ -1,9 +1,6 @@
pub mod audio;
pub mod bs1770;
pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;

use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};

@ -40,7 +40,7 @@ impl TokenOutputStream {
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();

@ -1,56 +0,0 @@
use std::io::prelude::*;

pub trait Sample {
    fn to_i16(&self) -> i16;
}

impl Sample for f32 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for f64 {
    fn to_i16(&self) -> i16 {
        (self.clamp(-1.0, 1.0) * 32767.0) as i16
    }
}

impl Sample for i16 {
    fn to_i16(&self) -> i16 {
        *self
    }
}

pub fn write_pcm_as_wav<W: Write, S: Sample>(
    w: &mut W,
    samples: &[S],
    sample_rate: u32,
) -> std::io::Result<()> {
    let len = 12u32; // header
    let len = len + 24u32; // fmt
    let len = len + samples.len() as u32 * 2 + 8; // data
    let n_channels = 1u16;
    let bytes_per_second = sample_rate * 2 * n_channels as u32;
    w.write_all(b"RIFF")?;
    w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
    w.write_all(b"WAVE")?;

    // Format block
    w.write_all(b"fmt ")?;
    w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
    w.write_all(&1u16.to_le_bytes())?; // PCM
    w.write_all(&n_channels.to_le_bytes())?; // one channel
    w.write_all(&sample_rate.to_le_bytes())?;
    w.write_all(&bytes_per_second.to_le_bytes())?;
    w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
    w.write_all(&16u16.to_le_bytes())?; // bits per sample

    // Data block
    w.write_all(b"data")?;
    w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
    for sample in samples.iter() {
        w.write_all(&sample.to_i16().to_le_bytes())?
    }
    Ok(())
}
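
For reference, writing one second of a 440 Hz mono sine wave with this helper would look roughly like the following (a sketch; the output path is arbitrary):

```rust
use std::fs::File;

fn main() -> std::io::Result<()> {
    let sample_rate = 44_100u32;
    // One second of a 440 Hz sine at half amplitude, as f32 samples.
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| {
            let t = i as f32 / sample_rate as f32;
            0.5 * (2.0 * std::f32::consts::PI * 440.0 * t).sin()
        })
        .collect();
    let mut file = File::create("sine.wav")?;
    write_pcm_as_wav(&mut file, &samples, sample_rate)
}
```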

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.1"
version = "0.4.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.1"
version = "0.4.0"
edition = "2021"

description = "CUDA kernels for Candle"
@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

File diff suppressed because it is too large
@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.1"
version = "0.4.0"
edition = "2021"

description = "Metal kernels for Candle"

@ -5,7 +5,6 @@ use serde::Deserialize;
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
@ -20,8 +19,6 @@ pub enum Activation {
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
}

impl super::Module for Activation {
@ -41,7 +38,6 @@ impl super::Module for Activation {
            Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
        }
    }
}

@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig {
    pub output_padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
    // TODO: support groups.
}

impl Default for ConvTranspose1dConfig {
@ -86,7 +86,6 @@ impl Default for ConvTranspose1dConfig {
            output_padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}
@ -110,14 +109,6 @@ impl ConvTranspose1d {
    pub fn config(&self) -> &ConvTranspose1dConfig {
        &self.config
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for ConvTranspose1d {
@ -128,13 +119,12 @@ impl crate::Module for ConvTranspose1d {
            self.config.output_padding,
            self.config.stride,
            self.config.dilation,
            self.config.groups,
        )?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => {
                let b = bias.dims1()?;
                let bias = bias.reshape((1, b, 1))?;
                let bias = bias.reshape((1, b, 1, 1))?;
                Ok(x.broadcast_add(&bias)?)
            }
        }
@ -268,14 +258,6 @@ impl ConvTranspose2d {
    pub fn config(&self) -> &ConvTranspose2dConfig {
        &self.config
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for ConvTranspose2d {
@ -320,22 +302,6 @@ pub fn conv1d(
    Ok(Conv1d::new(ws, Some(bs), cfg))
}

pub fn conv1d_no_bias(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    cfg: Conv1dConfig,
    vb: crate::VarBuilder,
) -> Result<Conv1d> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints(
        (out_channels, in_channels / cfg.groups, kernel_size),
        "weight",
        init_ws,
    )?;
    Ok(Conv1d::new(ws, None, cfg))
}

pub fn conv_transpose1d(
    in_channels: usize,
    out_channels: usize,
@ -348,11 +314,7 @@ pub fn conv_transpose1d(
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints(
        (in_channels, out_channels / cfg.groups, kernel_size),
        "weight",
        init,
    )?;
    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
    let bs = vb.get_with_hints(out_channels, "bias", init)?;
    Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
@ -369,11 +331,7 @@ pub fn conv_transpose1d_no_bias(
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints(
        (in_channels, out_channels / cfg.groups, kernel_size),
        "weight",
        init,
    )?;
    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
    Ok(ConvTranspose1d::new(ws, None, cfg))
}

@ -19,16 +19,15 @@ pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
    conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
    conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
    ConvTranspose1d, ConvTranspose1dConfig, ConvTranspose2d, ConvTranspose2dConfig,
    conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,
    Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};

@ -57,34 +57,21 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let bound = 1. / (in_dim as f64).sqrt();
    let init_bs = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
    let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
    Ok(Linear::new(ws, Some(bs)))
}

/// Create or initialize a new linear layer without biases.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::VarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}

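The removed `linear_b` is a thin dispatcher over `linear`/`linear_no_bias`. A usage sketch (assuming `VarBuilder::zeros`, which candle-nn provides for testing, and the re-exports above):

```rust
use candle::{DType, Device, Tensor};
use candle_nn::{linear_b, Module, VarBuilder};

fn main() -> candle::Result<()> {
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);
    // A 4 -> 2 projection with a bias, looked up under the "proj" prefix.
    let layer = linear_b(4, 2, true, vb.pp("proj"))?;
    let xs = Tensor::ones((1, 4), DType::F32, &device)?;
    assert_eq!(layer.forward(&xs)?.dims(), &[1, 2]);
    Ok(())
}
```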
@ -197,7 +197,7 @@ impl RNN for LSTM {

    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
        let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
        Tensor::stack(&states, 1)
        Tensor::cat(&states, 1)
    }
}

@ -70,7 +70,7 @@ impl VarMap {
    ///
    /// If an error is returned, some of the variables might have already been set to their new
    /// values.
    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
        &mut self,
        iter: I,
    ) -> Result<()> {

@ -7,7 +7,7 @@ extern crate accelerate_src;
use candle::test_utils::{to_vec0_round, to_vec2_round};

use anyhow::Result;
use candle::{DType, Device, Tensor, Var};
use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};

#[test]
@ -121,40 +121,3 @@ fn adamw_linear_regression() -> Result<()> {
    assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
    Ok(())
}

#[test]
fn adamw_linear_regression_varmap() -> Result<()> {
    use candle_nn::Init::Const;

    // Similar to the previous test but using a VarMap.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    let mut var_map = candle_nn::VarMap::new();

    let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
    let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
    let params = ParamsAdamW {
        lr: 0.1,
        ..Default::default()
    };
    let mut opt = AdamW::new(var_map.all_vars(), params)?;
    let lin = Linear::new(w, Some(b));
    for _step in 0..100 {
        let ys = lin.forward(&sample_xs)?;
        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
        opt.backward_step(&loss)?;
    }
    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);

    var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
    var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;

    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
    Ok(())
}

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.4.1"
version = "0.4.0"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.1" }
candle-nn = { path = "../candle-nn", version = "0.4.1" }
candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" }
candle-nn = { path = "../candle-nn", version = "0.4.0" }
prost = "0.12.1"

[build-dependencies]

@ -832,49 +832,7 @@ fn test_flatten_operation() -> Result<()> {
// #[test]

// "Shape"
#[test]
fn test_shape_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Shape".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![ValueInfoProto {
            name: INPUT_X.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec1::<i64>()?;
    assert_eq!(results, vec![2, 2]);

    Ok(())
}
// #[test]

// "Conv"
// #[test]
@ -883,452 +841,31 @@ fn test_shape_operation() -> Result<()> {
// #[test]

// "Abs"
#[test]
fn test_abs_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Abs".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(
        vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

    Ok(())
}
// #[test]

// "Cos"
#[test]
fn test_cos_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Cos".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
    );

    Ok(())
}
// #[test]

// "Sin"
#[test]
fn test_sin_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Sin".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);

    Ok(())
}
// #[test]

// "Neg"
#[test]
fn test_neg_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Neg".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]);

    Ok(())
}
// #[test]

// "Erf"
// #[test]

// "Tanh"
#[test]
fn test_tanh_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Tanh".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]]
    );

    Ok(())
}
// #[test]

// "Sigmoid"
#[test]
fn test_sigmoid_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Sigmoid".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]]
    );

    Ok(())
}
// #[test]

// "Gelu"
#[test]
fn test_gelu_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Gelu".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]]
    );

    Ok(())
}
// #[test]

// "Relu"
#[test]
fn test_relu_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Relu".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![ValueInfoProto {
            name: INPUT_X.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(
        vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]);

    Ok(())
}
// #[test]

// "Constant"
// #[test]

@ -15,7 +15,6 @@ byteorder = { workspace = true }
candle = { workspace = true }
candle-flash-attn = { workspace = true, optional = true }
candle-nn = { workspace = true }
fancy-regex = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

@ -1,5 +1,15 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};

fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = if bias {
        Some(vb.get(size2, "bias")?)
    } else {
        None
    };
    Ok(Linear::new(weight, bias))
}

fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get(size, "weight")?;

@ -1,4 +1,4 @@
use crate::models::with_tracing::{linear_b as linear, Linear};
use crate::models::with_tracing::Linear;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

@ -51,6 +51,14 @@ impl Config {
    }
}

fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    if bias {
        crate::models::with_tracing::linear(in_dim, out_dim, vb)
    } else {
        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    cache: Tensor,

@ -1,460 +0,0 @@
//! EfficientViT (MSRA) inference implementation based on timm.
//!
//! See "EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention"
//! https://arxiv.org/abs/2305.07027

//! https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_msra.py

use candle::{Result, Tensor, D};
use candle_nn::{
    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, ops::softmax, Conv2dConfig, Func,
    VarBuilder,
};

#[derive(Clone)]
pub struct Config {
    channels: [usize; 3],
    blocks: [usize; 3],
    heads: [usize; 3],
    kernels: [usize; 4],
}

impl Config {
    pub fn m0() -> Self {
        Self {
            channels: [64, 128, 192],
            blocks: [1, 2, 3],
            heads: [4, 4, 4],
            kernels: [5, 5, 5, 5],
        }
    }
    pub fn m1() -> Self {
        Self {
            channels: [128, 144, 192],
            blocks: [1, 2, 3],
            heads: [2, 3, 3],
            kernels: [7, 5, 3, 3],
        }
    }
    pub fn m2() -> Self {
        Self {
            channels: [128, 192, 224],
            blocks: [1, 2, 3],
            heads: [4, 3, 2],
            kernels: [7, 5, 3, 3],
        }
    }
    pub fn m3() -> Self {
        Self {
            channels: [128, 240, 320],
            blocks: [1, 2, 3],
            heads: [4, 3, 4],
            kernels: [5, 5, 5, 5],
        }
    }
    pub fn m4() -> Self {
        Self {
            channels: [128, 256, 384],
            blocks: [1, 2, 3],
            heads: [4, 4, 4],
            kernels: [7, 5, 3, 3],
        }
    }

    pub fn m5() -> Self {
        Self {
            channels: [192, 288, 384],
            blocks: [1, 3, 4],
            heads: [3, 3, 4],
            kernels: [7, 5, 3, 3],
        }
    }
}

fn efficientvit_stemblock(
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride: 2,
        padding: 1,
        ..Default::default()
    };

    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
    let conv = conv2d_no_bias(in_channels, out_channels, 3, conv2d_cfg, vb.pp("conv"))?;

    Ok(Func::new(move |xs| {
        let xs = xs.apply(&conv)?.apply_t(&bn, false)?;
        Ok(xs)
    }))
}

fn efficientvit_stem(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let conv1 = efficientvit_stemblock(3, dim / 8, vb.pp("conv1"))?;
    let conv2 = efficientvit_stemblock(dim / 8, dim / 4, vb.pp("conv2"))?;
    let conv3 = efficientvit_stemblock(dim / 4, dim / 2, vb.pp("conv3"))?;
    let conv4 = efficientvit_stemblock(dim / 2, dim, vb.pp("conv4"))?;

    Ok(Func::new(move |xs| {
        let xs = xs
            .apply(&conv1)?
            .relu()?
            .apply(&conv2)?
            .relu()?
            .apply(&conv3)?
            .relu()?
            .apply(&conv4)?;

        Ok(xs)
    }))
}

fn depthwise_conv(
    channels: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride,
        padding,
        groups: channels,
        ..Default::default()
    };

    let bn = batch_norm(channels, 1e-5, vb.pp("bn"))?;
    let conv = conv2d_no_bias(channels, channels, kernel, conv2d_cfg, vb.pp("conv"))?;

    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
}

fn pointwise_conv(
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        ..Default::default()
    };

    let bn = batch_norm(out_channels, 1e-5, vb.pp("bn"))?;
    let conv = conv2d_no_bias(in_channels, out_channels, 1, conv2d_cfg, vb.pp("conv"))?;

    Ok(Func::new(move |xs| xs.apply(&conv)?.apply_t(&bn, false)))
}

fn conv_mlp(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let pw1 = pointwise_conv(in_channels, out_channels, vb.pp("pw1"))?;
    let pw2 = pointwise_conv(out_channels, in_channels, vb.pp("pw2"))?;

    Ok(Func::new(move |xs| {
        let xs = xs.apply(&pw1)?.relu()?.apply(&pw2)?;
        Ok(xs)
    }))
}

// Fixed per-stage resolutions
const RESOLUTIONS: [usize; 3] = [14, 7, 4];

// Attention block
fn efficientvit_attn(
    cfg: &Config,
    stage: usize,
    in_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cga = cascaded_group_attn(cfg, stage, in_channels, vb)?;

    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();

        let (b, c, h, w) = xs.dims4()?;
        let win_res = 7; // Fixed window resolution
        let pad_b = (win_res - h % win_res) % win_res;
        let pad_r = (win_res - w % win_res) % win_res;
        let ph = h + pad_b;
        let pw = w + pad_r;
        let nh = ph / win_res;
        let nw = pw / win_res;

        if RESOLUTIONS[stage] > win_res {
            xs = xs.permute((0, 2, 3, 1))?;
            xs = xs.pad_with_zeros(D::Minus1, 0, pad_r)?;
            xs = xs.pad_with_zeros(D::Minus2, 0, pad_b)?;
            xs = xs
                .reshape((b, nh, win_res, nw, win_res, c))?
                .transpose(2, 3)?;
            xs = xs
                .reshape((b * nh * nw, win_res, win_res, c))?
                .permute((0, 3, 1, 2))?;
        }

        xs = xs.apply(&cga)?;

        if RESOLUTIONS[stage] > win_res {
            xs = xs
                .permute((0, 2, 3, 1))?
                .reshape((b, nh, nw, win_res, win_res, c))?;
            xs = xs.transpose(2, 3)?.reshape((b, ph, pw, c))?;
            xs = xs.permute((0, 3, 1, 2))?;
        }

        Ok(xs)
    }))
}

// Cascaded group attention
fn cascaded_group_attn(
    cfg: &Config,
    stage: usize,
    in_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let heads = cfg.heads[stage];
    let key_dim = 16;

    let val_dim = in_channels / heads;

    let scale = (key_dim as f64).powf(-0.5);

    let mut dws = Vec::with_capacity(heads);
    let mut qkvs = Vec::with_capacity(heads);
    for i in 0..heads {
        dws.push(depthwise_conv(
            key_dim,
            cfg.kernels[i],
            1,
            cfg.kernels[i] / 2,
            vb.pp(format!("dws.{i}")),
        )?);

        qkvs.push(pointwise_conv(
            in_channels / heads,
            in_channels / heads + 2 * key_dim,
            vb.pp(format!("qkvs.{i}")),
        )?);
    }
    let proj = pointwise_conv(in_channels, in_channels, vb.pp("proj.1"))?;

    Ok(Func::new(move |xs| {
        let (b, _, h, w) = xs.dims4()?;
        let feats_in = xs.chunk(heads, 1)?;
        let mut feats_out = Vec::with_capacity(heads);
        let mut feat = feats_in[0].clone();

        for i in 0..heads {
            if i > 0 {
                feat = (&feat + &feats_in[i])?;
            }
            feat = feat.apply(&qkvs[i])?;
            let res = feat.reshape((b, (), h, w))?;
            let q = res.narrow(1, 0, key_dim)?;
            let k = res.narrow(1, key_dim, key_dim)?;
            let v = res.narrow(1, 2 * key_dim, val_dim)?;

            let q = q.apply(&dws[i])?;

            let q = q.flatten_from(2)?;
            let k = k.flatten_from(2)?;
            let v = v.flatten_from(2)?;
            let q = (q * scale)?;

            let att = q.transpose(D::Minus2, D::Minus1)?.matmul(&k)?;
            let att = softmax(&att, D::Minus1)?;
            feat = v.matmul(&att.transpose(D::Minus2, D::Minus1)?)?;
            feat = feat.reshape((b, val_dim, h, w))?;
            feats_out.push(feat.clone());
        }

        let xs = Tensor::cat(&feats_out, 1)?;
        let xs = xs.relu()?.apply(&proj)?;

        Ok(xs)
    }))
}

// Used by the downsampling layer
fn squeeze_and_excitation(
    in_channels: usize,
    squeeze_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        ..Default::default()
    };
    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;

    Ok(Func::new(move |xs| {
        let residual = xs;
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;

        residual.broadcast_mul(&xs)
    }))
}

// Used by the downsampling layer
fn patchmerge(in_channels: usize, out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let dim = in_channels;
    let hid_dim = in_channels * 4;
    let conv1 = pointwise_conv(dim, hid_dim, vb.pp("conv1"))?;
    let conv2 = depthwise_conv(hid_dim, 3, 2, 1, vb.pp("conv2"))?;
    let conv3 = pointwise_conv(hid_dim, out_channels, vb.pp("conv3"))?;
    let se = squeeze_and_excitation(hid_dim, hid_dim / 4, vb.pp("se"))?;
    Ok(Func::new(move |xs| {
        let xs = xs
            .apply(&conv1)?
            .relu()?
            .apply(&conv2)?
            .relu()?
            .apply(&se)?
            .apply(&conv3)?;
        Ok(xs)
    }))
}

// Used by the downsampling layer
fn res(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let dw = depthwise_conv(dim, 3, 1, 1, vb.pp("0.m"))?;
    let mlp = conv_mlp(dim, dim * 2, vb.pp("1.m"))?;
    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();
        xs = (&xs + &xs.apply(&dw)?)?;
        xs = (&xs + &xs.apply(&mlp)?)?;
        Ok(xs)
    }))
}

// Downsampling
fn efficientvit_downsample(
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let res1 = res(in_channels, vb.pp("res1"))?;
    let res2 = res(out_channels, vb.pp("res2"))?;
    let patchmerge = patchmerge(in_channels, out_channels, vb.pp("patchmerge"))?;
    Ok(Func::new(move |xs| {
        let xs = xs.apply(&res1)?.apply(&patchmerge)?.apply(&res2)?;
        Ok(xs)
    }))
}

fn efficientvit_block(
    cfg: &Config,
    stage: usize,
    dim: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let dw0 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw0.m"))?;
    let dw1 = depthwise_conv(dim, 3, 1, 1, vb.pp("dw1.m"))?;
    let ffn0 = conv_mlp(dim, dim * 2, vb.pp("ffn0.m"))?;
    let ffn1 = conv_mlp(dim, dim * 2, vb.pp("ffn1.m"))?;
    let attn = efficientvit_attn(cfg, stage, dim, vb.pp("mixer.m.attn"))?;
    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();
        xs = (&xs + &xs.apply(&dw0)?)?;
        xs = (&xs + &xs.apply(&ffn0)?)?;
        xs = (&xs + &xs.apply(&attn)?)?;
        xs = (&xs + &xs.apply(&dw1)?)?;
        xs = (&xs + &xs.apply(&ffn1)?)?;
        Ok(xs)
    }))
}

// Each stage is made of blocks. There is a downsampling layer between stages.
fn efficientvit_stage(cfg: &Config, stage: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let nblocks = cfg.blocks[stage];
    let mut blocks = Vec::with_capacity(nblocks + 1);

    let in_channels = if stage > 0 {
        cfg.channels[stage - 1]
    } else {
        cfg.channels[0]
    };
    let out_channels = cfg.channels[stage];

    if stage > 0 {
        blocks.push(efficientvit_downsample(
            in_channels,
            out_channels,
            vb.pp("downsample"),
        )?);
    }

    for i in 0..nblocks {
        blocks.push(efficientvit_block(
            cfg,
            stage,
            out_channels,
            vb.pp(format!("blocks.{i}")),
        )?);
    }

    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();
        for block in blocks.iter() {
            xs = xs.apply(block)?
        }
        Ok(xs)
    }))
}

// Classification head.
fn efficientvit_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let norm = batch_norm(outputs, 1e-6, vb.pp("bn"))?;
    let linear = linear(outputs, nclasses, vb.pp("linear"))?;
    Ok(Func::new(move |xs| {
        xs.apply_t(&norm, false)?.apply(&linear)
    }))
}

// Build an efficientvit model for a given configuration.
fn efficientvit_model(
    config: &Config,
    nclasses: Option<usize>,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cls = match nclasses {
        None => None,
        Some(nclasses) => {
            let outputs = config.channels[2];
            let head = efficientvit_head(outputs, nclasses, vb.pp("head"))?;
            Some(head)
        }
    };

    let stem_dim = config.channels[0];
    let stem = efficientvit_stem(stem_dim, vb.pp("patch_embed"))?;

    let vb = vb.pp("stages");
    let stage1 = efficientvit_stage(config, 0, vb.pp(0))?;
    let stage2 = efficientvit_stage(config, 1, vb.pp(1))?;
    let stage3 = efficientvit_stage(config, 2, vb.pp(2))?;

    Ok(Func::new(move |xs| {
        let xs = xs
            .apply(&stem)?
            .apply(&stage1)?
            .apply(&stage2)?
            .apply(&stage3)?
            .mean(D::Minus2)?
            .mean(D::Minus1)?;
        match &cls {
            None => Ok(xs),
            Some(cls) => xs.apply(cls),
        }
    }))
}

pub fn efficientvit(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    efficientvit_model(cfg, Some(nclasses), vb)
}

pub fn efficientvit_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    efficientvit_model(cfg, None, vb)
}
@ -1,773 +0,0 @@
#![allow(unused)]
use candle::{DType, IndexOp, Layout, Module, Result, Shape, Tensor, D};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, ConvTranspose1d, VarBuilder};

// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py

#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum NormType {
    WeightNorm,
    TimeGroupNorm,
    None,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize)]
pub enum PadMode {
    Constant,
    Reflect,
    Replicate,
}

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    pub target_bandwidths: Vec<f64>,
    pub sampling_rate: usize,
    pub audio_channels: usize,
    pub normalize: bool,
    pub chunk_length_s: Option<usize>,
    pub overlap: Option<usize>,
    pub hidden_size: usize,
    pub num_filters: usize,
    pub num_residual_layers: usize,
    pub upsampling_ratios: Vec<usize>,
    pub norm_type: NormType,
    pub kernel_size: usize,
    pub last_kernel_size: usize,
    pub residual_kernel_size: usize,
    pub dilation_growth_rate: usize,
    pub use_causal_conv: bool,
    pub pad_mode: PadMode,
    pub compress: usize,
    pub num_lstm_layers: usize,
    pub trim_right_ratio: f64,
    pub codebook_size: usize,
    pub codebook_dim: Option<usize>,
    pub use_conv_shortcut: bool,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
            sampling_rate: 24_000,
            audio_channels: 1,
            normalize: false,
            chunk_length_s: None,
            overlap: None,
            hidden_size: 128,
            num_filters: 32,
            num_residual_layers: 1,
            upsampling_ratios: vec![8, 5, 4, 2],
            norm_type: NormType::WeightNorm,
            kernel_size: 7,
            last_kernel_size: 7,
            residual_kernel_size: 3,
            dilation_growth_rate: 2,
            use_causal_conv: true,
            // This should be PadMode::Reflect which is currently unsupported in candle.
            pad_mode: PadMode::Replicate,
            compress: 2,
            num_lstm_layers: 2,
            trim_right_ratio: 1.0,
            codebook_size: 1024,
            codebook_dim: None,
            use_conv_shortcut: true,
        }
    }
}

impl Config {
    fn codebook_dim(&self) -> usize {
        self.codebook_dim.unwrap_or(self.hidden_size)
    }

    fn frame_rate(&self) -> usize {
        let hop_length: usize = self.upsampling_ratios.iter().product();
        (self.sampling_rate + hop_length - 1) / hop_length
    }

    fn num_quantizers(&self) -> usize {
        let num = 1000f64
            * self
                .target_bandwidths
                .last()
                .expect("empty target_bandwidths");
        (num as usize) / (self.frame_rate() * 10)
    }
}
|
||||
|
||||
fn get_extra_padding_for_conv1d(
|
||||
xs: &Tensor,
|
||||
k_size: usize,
|
||||
stride: usize,
|
||||
padding_total: usize,
|
||||
) -> Result<usize> {
|
||||
let len = xs.dim(D::Minus1)?;
|
||||
let n_frames = (len + padding_total).saturating_sub(k_size) as f64 / stride as f64 + 1.0;
|
||||
let ideal_len =
|
||||
((n_frames.ceil() as usize - 1) * stride + k_size).saturating_sub(padding_total);
|
||||
Ok(ideal_len.saturating_sub(len))
|
||||
}
|
||||
|
||||
fn pad1d(xs: &Tensor, pad_l: usize, pad_r: usize, mode: PadMode) -> Result<Tensor> {
|
||||
match mode {
|
||||
PadMode::Constant => xs.pad_with_zeros(D::Minus1, pad_l, pad_r),
|
||||
PadMode::Reflect => candle::bail!("pad-mode 'reflect' is not supported"),
|
||||
PadMode::Replicate => xs.pad_with_same(D::Minus1, pad_l, pad_r),
|
||||
}
|
||||
}
|
||||
|
||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||
// does not apply to training.
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
|
||||
pub fn conv1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
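
For reference, the reparameterization being undone here is weight normalization: each output channel stores a direction tensor `weight_v` and a per-channel gain `weight_g`, and the effective kernel is

$$w = g \cdot \frac{v}{\lVert v \rVert},$$

with the norm taken over the in-channel and kernel axes of each output channel, which is exactly what `sum_keepdim((1, 2))` computes before the broadcast multiply/divide.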

fn conv_transpose1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    bias: bool,
    config: candle_nn::ConvTranspose1dConfig,
    vb: VarBuilder,
) -> Result<ConvTranspose1d> {
    let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
    let weight_v = vb.get((in_c, out_c, kernel_size), "weight_v")?;
    let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
    let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
    let bias = if bias {
        Some(vb.get(out_c, "bias")?)
    } else {
        None
    };
    Ok(ConvTranspose1d::new(weight, bias, config))
}

struct CodebookEncode;

impl candle::CustomOp2 for CodebookEncode {
    fn name(&self) -> &'static str {
        "cb"
    }

    fn cpu_fwd(
        &self,
        lhs_storage: &candle::CpuStorage,
        lhs_layout: &Layout,
        rhs_storage: &candle::CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<(candle::CpuStorage, Shape)> {
        use rayon::prelude::*;

        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
        if lhs_dim2 != rhs_dim2 {
            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
        }
        if lhs_dim2 == 0 {
            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
        }
        let lhs = match lhs_layout.contiguous_offsets() {
            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
            Some((o1, o2)) => {
                let slice = lhs_storage.as_slice::<f32>()?;
                &slice[o1..o2]
            }
        };
        let rhs = match rhs_layout.contiguous_offsets() {
            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
            Some((o1, o2)) => {
                let slice = rhs_storage.as_slice::<f32>()?;
                &slice[o1..o2]
            }
        };
        let dst = (0..lhs_dim1)
            .into_par_iter()
            .map(|idx1| {
                let mut where_min = 0;
                let mut min_dist = f32::INFINITY;
                let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
                for idx2 in 0..rhs_dim1 {
                    let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
                    let mut dist = 0f32;
                    for (a, b) in lhs.iter().zip(rhs.iter()) {
                        dist += (a - b) * (a - b)
                    }
                    if dist < min_dist {
                        min_dist = dist;
                        where_min = idx2;
                    }
                }
                where_min as u32
            })
            .collect();
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, (lhs_dim1,).into()))
    }
}

// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
#[derive(Clone, Debug)]
pub struct EuclideanCodebook {
    inited: Tensor,
    cluster_size: Tensor,
    embed: candle_nn::Embedding,
    embed_avg: Tensor,
    c2: Tensor,
}

impl EuclideanCodebook {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let inited = vb.get(1, "inited")?;
        let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
        let e_shape = (cfg.codebook_size, cfg.codebook_dim());
        let embed = vb.get(e_shape, "embed")?;
        let c2 = ((&embed * &embed)?.sum(D::Minus1)? / 2.0)?;
        let embed_avg = vb.get(e_shape, "embed_avg")?;
        Ok(Self {
            inited,
            cluster_size,
            embed: candle_nn::Embedding::new(embed, cfg.codebook_dim()),
            embed_avg,
            c2,
        })
    }

    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
        let mut target_shape = xs.dims().to_vec();
        target_shape.pop();
        let xs = xs.flatten_to(D::Minus2)?;
        let _ = xs.dims2()?;
        let dot_prod = xs.matmul(&self.embed.embeddings().t()?)?;
        let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
        codes.reshape(target_shape)
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let mut target_shape = xs.dims().to_vec();
        target_shape.pop();
        let xs = xs.flatten_to(D::Minus2)?;
        let _ = xs.dims2()?;
        let codes = Tensor::apply_op2(&xs, self.embed.embeddings(), CodebookEncode)?;
        codes.reshape(target_shape)
    }

    pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        let quantize = self.embed.forward(embed_ind)?;
        Ok(quantize)
    }
}
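
The `c2` buffer precomputed in `new` is what lets `encode_slow` reduce the nearest-codebook-entry search to a matmul plus argmin: for an input row $x$ and codebook rows $e_i$,

$$\arg\min_i \lVert x - e_i \rVert^2 = \arg\min_i \left(\tfrac{1}{2}\lVert e_i \rVert^2 - x \cdot e_i\right),$$

since $\lVert x \rVert^2$ is constant across $i$. The `CodebookEncode` custom op used by `encode` performs the same search by computing the squared distances directly on CPU, parallelized over input rows with rayon.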

#[derive(Clone, Debug)]
pub struct VectorQuantization {
    codebook: EuclideanCodebook,
}

impl VectorQuantization {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let codebook = EuclideanCodebook::new(cfg, vb.pp("codebook"))?;
        Ok(Self { codebook })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.transpose(1, 2)?;
        self.codebook.encode_slow(&xs)
    }

    pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
        let quantize = self.codebook.decode(embed_ind)?;
        let quantize = quantize.transpose(1, 2)?;
        Ok(quantize)
    }
}

#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
    layers: Vec<VectorQuantization>,
    dtype: DType,
}

impl ResidualVectorQuantizer {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = &vb.pp("layers");
        let layers = (0..cfg.num_quantizers())
            .map(|i| VectorQuantization::new(cfg, vb.pp(i)))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            layers,
            dtype: vb.dtype(),
        })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let mut codes = Vec::with_capacity(self.layers.len());
        let mut residual = xs.clone();
        for layer in self.layers.iter() {
            let indices = layer.encode(&residual)?;
            let quantized = layer.decode(&indices)?;
            residual = (residual - quantized)?;
            codes.push(indices)
        }
        Tensor::stack(&codes, 0)
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
        let ncodes = codes.dim(0)?;
        if ncodes > self.layers.len() {
            candle::bail!(
                "codes shape {:?} does not match the number of quantization layers {}",
                codes.shape(),
                self.layers.len()
            )
        }
        for (i, layer) in self.layers.iter().take(ncodes).enumerate() {
            let quantized = layer.decode(&codes.i(i)?)?;
            quantized_out = quantized.broadcast_add(&quantized_out)?;
        }
        Ok(quantized_out)
    }
}

// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
#[derive(Clone, Debug)]
pub struct EncodecLSTM {
    layers: Vec<candle_nn::LSTM>,
}

impl EncodecLSTM {
    pub fn new(dim: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = &vb.pp("lstm");
        let mut layers = vec![];
        for layer_idx in 0..cfg.num_lstm_layers {
            let config = candle_nn::LSTMConfig {
                layer_idx,
                ..Default::default()
            };
            let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
            layers.push(lstm)
        }
        Ok(Self { layers })
    }
}

impl Module for EncodecLSTM {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        use candle_nn::RNN;
        // This is different from the Python transformers version as candle LSTM is batch first.
        let xs = xs.t()?;
        let residual = &xs;
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            let states = layer.seq(&xs)?;
            xs = layer.states_to_tensor(&states)?;
        }
        let xs = (xs + residual)?.t()?;
        Ok(xs)
    }
}

#[derive(Clone, Debug)]
pub struct EncodecConvTranspose1d {
    conv: ConvTranspose1d,
}

impl EncodecConvTranspose1d {
    fn new(
        in_c: usize,
        out_c: usize,
        k: usize,
        stride: usize,
        _cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let cfg = candle_nn::ConvTranspose1dConfig {
            stride,
            ..Default::default()
        };
        let conv = conv_transpose1d_weight_norm(in_c, out_c, k, true, cfg, vb.pp("conv"))?;
        Ok(Self { conv })
    }
}

impl Module for EncodecConvTranspose1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.conv)
    }
}

#[derive(Clone, Debug)]
pub struct EncodecConv1d {
    causal: bool,
    conv: Conv1d,
    norm: Option<candle_nn::GroupNorm>,
    pad_mode: PadMode,
}

impl EncodecConv1d {
    pub fn new(
        in_c: usize,
        out_c: usize,
        kernel_size: usize,
        stride: usize,
        dilation: usize,
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let conv = match cfg.norm_type {
            NormType::WeightNorm => conv1d_weight_norm(
                in_c,
                out_c,
                kernel_size,
                candle_nn::Conv1dConfig {
                    stride,
                    dilation,
                    ..Default::default()
                },
                vb.pp("conv"),
            )?,
            NormType::None | NormType::TimeGroupNorm => conv1d(
                in_c,
                out_c,
                kernel_size,
                candle_nn::Conv1dConfig {
                    padding: 0,
                    stride,
                    groups: 1,
                    dilation: 1,
                },
                vb.pp("conv"),
            )?,
        };
        let norm = match cfg.norm_type {
            NormType::None | NormType::WeightNorm => None,
            NormType::TimeGroupNorm => {
                let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
                Some(gn)
            }
        };
        Ok(Self {
            causal: cfg.use_causal_conv,
            conv,
            norm,
            pad_mode: cfg.pad_mode,
        })
    }
}

impl Module for EncodecConv1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_b, _t, _c) = xs.dims3()?;
        let k_size = self.conv.weight().dim(D::Minus1)?;
        let conv_cfg = self.conv.config();
        // Effective kernel size with dilations.
        let k_size = (k_size - 1) * conv_cfg.dilation + 1;
        let padding_total = k_size - conv_cfg.stride;
        let extra_padding =
            get_extra_padding_for_conv1d(xs, k_size, conv_cfg.stride, padding_total)?;
        let xs = if self.causal {
            pad1d(xs, padding_total, extra_padding, self.pad_mode)?
        } else {
            let padding_right = padding_total / 2;
            let padding_left = padding_total - padding_right;
            pad1d(
                xs,
                padding_left,
                padding_right + extra_padding,
                self.pad_mode,
            )?
        };
        let xs = self.conv.forward(&xs)?;
        match &self.norm {
            None => Ok(xs),
            Some(norm) => xs.apply(norm),
        }
    }
}

#[derive(Clone, Debug)]
pub struct EncodecResnetBlock {
    block_conv1: EncodecConv1d,
    block_conv2: EncodecConv1d,
    shortcut: Option<EncodecConv1d>,
}

impl EncodecResnetBlock {
    pub fn new(
        dim: usize,
        (dilation1, dilation2): (usize, usize),
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let h = dim / cfg.compress;
        let mut layer = Layer::new(vb.pp("block"));
        // TODO: Apply dilations!
        layer.inc();
        let block_conv1 = EncodecConv1d::new(
            dim,
            h,
            cfg.residual_kernel_size,
            1,
            dilation1,
            cfg,
            layer.next(),
        )?;
        layer.inc();
        let block_conv2 = EncodecConv1d::new(h, dim, 1, 1, dilation2, cfg, layer.next())?;
        let shortcut = if cfg.use_conv_shortcut {
            let conv = EncodecConv1d::new(dim, dim, 1, 1, 1, cfg, vb.pp("shortcut"))?;
            Some(conv)
        } else {
            None
        };
        Ok(Self {
            block_conv1,
            block_conv2,
            shortcut,
        })
    }
}

impl Module for EncodecResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs.clone();
        let xs = xs.elu(1.)?;
        let xs = self.block_conv1.forward(&xs)?;
        let xs = xs.elu(1.)?;
        let xs = self.block_conv2.forward(&xs)?;
        let xs = match &self.shortcut {
            None => (xs + residual)?,
            Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
        };
        Ok(xs)
    }
}

struct Layer<'a> {
    vb: VarBuilder<'a>,
    cnt: usize,
}

impl<'a> Layer<'a> {
    fn new(vb: VarBuilder<'a>) -> Self {
        Self { vb, cnt: 0 }
    }

    fn inc(&mut self) {
        self.cnt += 1;
    }

    fn next(&mut self) -> VarBuilder {
        let vb = self.vb.pp(&self.cnt.to_string());
        self.cnt += 1;
        vb
    }
}

#[derive(Clone, Debug)]
pub struct Encoder {
    init_conv: EncodecConv1d,
    sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
    final_lstm: EncodecLSTM,
    final_conv: EncodecConv1d,
}

impl Encoder {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let init_conv = EncodecConv1d::new(
            cfg.audio_channels,
            cfg.num_filters,
            cfg.kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        let mut sampling_layers = vec![];
        let mut scaling = 1;
        for &ratio in cfg.upsampling_ratios.iter().rev() {
            let current_scale = scaling * cfg.num_filters;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::new(
                    current_scale,
                    (cfg.dilation_growth_rate.pow(j), 1),
                    cfg,
                    layer.next(),
                )?;
                resnets.push(resnet)
            }
            layer.inc(); // ELU
            let conv1d = EncodecConv1d::new(
                current_scale,
                current_scale * 2,
                ratio * 2,
                ratio,
                1,
                cfg,
                layer.next(),
            )?;
            sampling_layers.push((resnets, conv1d));
            scaling *= 2;
        }
        let final_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::new(
            cfg.num_filters * scaling,
            cfg.hidden_size,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        Ok(Self {
            init_conv,
            sampling_layers,
            final_conv,
            final_lstm,
        })
    }
}

impl Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?;
        for (resnets, conv) in self.sampling_layers.iter() {
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?;
            }
            xs = xs.elu(1.0)?.apply(conv)?;
        }
        xs.apply(&self.final_lstm)?
            .elu(1.0)?
            .apply(&self.final_conv)
    }
}

#[derive(Clone, Debug)]
pub struct Decoder {
    init_conv: EncodecConv1d,
    init_lstm: EncodecLSTM,
    sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
    final_conv: EncodecConv1d,
}

impl Decoder {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layer = Layer::new(vb.pp("layers"));
        let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
        let init_conv = EncodecConv1d::new(
            cfg.hidden_size,
            cfg.num_filters * scaling,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        let init_lstm = EncodecLSTM::new(cfg.num_filters * scaling, cfg, layer.next())?;
        let mut sampling_layers = vec![];
        for &ratio in cfg.upsampling_ratios.iter() {
            let current_scale = scaling * cfg.num_filters;
            layer.inc(); // ELU
            let conv1d = EncodecConvTranspose1d::new(
                current_scale,
                current_scale / 2,
                ratio * 2,
                ratio,
                cfg,
                layer.next(),
            )?;
            let mut resnets = vec![];
            for j in 0..(cfg.num_residual_layers as u32) {
                let resnet = EncodecResnetBlock::new(
                    current_scale / 2,
                    (cfg.dilation_growth_rate.pow(j), 1),
                    cfg,
                    layer.next(),
                )?;
                resnets.push(resnet)
            }
            sampling_layers.push((conv1d, resnets));
            scaling /= 2;
        }
        layer.inc(); // ELU
        let final_conv = EncodecConv1d::new(
            cfg.num_filters,
            cfg.audio_channels,
            cfg.last_kernel_size,
            1,
            1,
            cfg,
            layer.next(),
        )?;
        Ok(Self {
            init_conv,
            init_lstm,
            sampling_layers,
            final_conv,
        })
    }
}

impl Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
        for (conv, resnets) in self.sampling_layers.iter() {
            xs = xs.elu(1.)?.apply(conv)?;
            for resnet in resnets.iter() {
                xs = xs.apply(resnet)?
            }
        }
        xs.elu(1.)?.apply(&self.final_conv)
    }
}

#[derive(Debug)]
pub struct Model {
    encoder: Encoder,
    decoder: Decoder,
    quantizer: ResidualVectorQuantizer,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
        let quantizer = ResidualVectorQuantizer::new(cfg, vb.pp("quantizer"))?;
        Ok(Self {
            encoder,
            decoder,
            quantizer,
        })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.encoder.forward(xs)?;
        let codes = self.quantizer.encode(&xs)?;
        codes.transpose(0, 1)
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let (_b_sz, _codebooks, _seqlen) = codes.dims3()?;
        let codes = codes.transpose(0, 1)?;
        let embeddings = self.quantizer.decode(&codes)?;
        let outputs = self.decoder.forward(&embeddings)?;
        Ok(outputs)
    }
}
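
A hypothetical encode/decode round trip for the model above (the checkpoint name and the one-second mono input are assumptions, not part of this diff):

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;

fn roundtrip() -> candle::Result<()> {
    let device = Device::Cpu;
    // Assumed checkpoint path holding encodec weights.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["encodec.safetensors"], DType::F32, &device)?
    };
    let model = Model::new(&Config::default(), vb)?;
    // One second of silence at 24kHz, shape (batch, audio_channels, samples).
    let pcm = Tensor::zeros((1, 1, 24_000), DType::F32, &device)?;
    let codes = model.encode(&pcm)?; // (batch, num_quantizers, frames)
    let reconstructed = model.decode(&codes)?;
    println!("{:?} -> {:?}", codes.shape(), reconstructed.shape());
    Ok(())
}
```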
@ -1,8 +1,18 @@
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};

const MAX_SEQ_LEN: usize = 5000;

fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let weight = vb.get((size2, size1), "weight")?;
    let bias = if bias {
        Some(vb.get(size2, "bias")?)
    } else {
        None
    };
    Ok(Linear::new(weight, bias))
}

fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
        (Ok(weight), Ok(bias)) => (weight, bias),
@ -1,381 +0,0 @@
use std::sync::Arc;

use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};

fn default_max_position_embeddings() -> usize {
    4096
}

#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub attention_bias: bool,
    pub head_dim: usize,
    pub hidden_act: candle_nn::Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,

    #[serde(default = "default_max_position_embeddings")]
    pub max_position_embeddings: usize,
}

#[derive(Debug, Clone)]
struct RmsNorm {
    weight: Tensor,
    eps: f64,
}

impl RmsNorm {
    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(dim, "weight")?;
        Ok(Self { weight, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&(&self.weight + 1.0)?)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.head_dim;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}
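
For reference, `apply_rotary_emb_qkv` is rotary position embedding in its "rotate half" form. With per-pair frequencies $\theta_j = \texttt{rope\_theta}^{-2j/d}$ and position $p$, both $q$ and $k$ are mapped through

$$x' = x \odot \cos(p\,\theta) + \operatorname{rotate\_half}(x) \odot \sin(p\,\theta),$$

so the resulting $q \cdot k$ attention scores depend only on the relative position between the two tokens.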

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: candle_nn::Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
        let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
        let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, ()))?
            .apply(&self.o_proj)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm =
            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    device: Device,
    dtype: DType,
    hidden_size: usize,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: vb.device().clone(),
            dtype: vb.dtype(),
            hidden_size: cfg.hidden_size,
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}
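
A hypothetical greedy decoding loop driving the model above (tokenizer and checkpoint handling are elided; the shapes follow `forward`, which only returns logits for the last position):

```rust
use candle::{Device, Tensor};

fn greedy_decode(model: &mut Model, prompt: &[u32], steps: usize) -> candle::Result<Vec<u32>> {
    let device = Device::Cpu;
    let mut tokens = prompt.to_vec();
    model.clear_kv_cache();
    for step in 0..steps {
        // After the first step only the newest token is fed; the kv-cache plus
        // `seqlen_offset` supply the earlier context.
        let (ctx, offset) = if step == 0 {
            (tokens.as_slice(), 0)
        } else {
            (&tokens[tokens.len() - 1..], tokens.len() - 1)
        };
        let input = Tensor::new(ctx, &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, offset)?; // (1, 1, vocab_size)
        let next = logits.squeeze(0)?.squeeze(0)?.argmax(0)?.to_scalar::<u32>()?;
        tokens.push(next);
    }
    Ok(tokens)
}
```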
@ -2,6 +2,7 @@ use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

pub const MAX_SEQ_LEN: usize = 4096;

@ -83,9 +84,10 @@ impl Config {

#[derive(Debug, Clone)]
pub struct Cache {
    masks: HashMap<usize, Tensor>,
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    kvs: Vec<Option<(Tensor, Tensor)>>,
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    device: Device,
@ -110,24 +112,25 @@ impl Cache {
        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;
        Ok(Self {
            masks: HashMap::new(),
            masks: Arc::new(Mutex::new(HashMap::new())),
            use_kv_cache,
            kvs: vec![None; config.num_hidden_layers],
            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
            device: device.clone(),
            cos,
            sin,
        })
    }

    fn mask(&mut self, t: usize) -> Result<Tensor> {
        if let Some(mask) = self.masks.get(&t) {
    fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
            self.masks.insert(t, mask.clone());
            masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
@ -161,6 +164,7 @@ struct CausalSelfAttention {
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    cache: Cache,
    use_flash_attn: bool,
    span: tracing::Span,
    span_rot: tracing::Span,
@ -183,11 +187,11 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
}

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let _enter = self.span_rot.enter();
        let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
@ -197,13 +201,7 @@ impl CausalSelfAttention {
        Ok(rope)
    }

    fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b_sz, seq_len, hidden_size) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
@ -220,11 +218,12 @@ impl CausalSelfAttention {
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if cache.use_kv_cache {
            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                let k_seq_len = k.dims()[1];
@ -240,7 +239,7 @@ impl CausalSelfAttention {
                        .contiguous()?
                }
            }
            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
            cache[block_idx] = Some((k.clone(), v.clone()))
        }

        let k = self.repeat_kv(k)?;
@ -259,7 +258,7 @@ impl CausalSelfAttention {
            let k = k.to_dtype(DType::F32)?;
            let v = v.to_dtype(DType::F32)?;
            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
            let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
            // Convert to contiguous as matmul doesn't support strided vs for now.
@ -284,7 +283,7 @@ impl CausalSelfAttention {
        }
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "attn");
        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
        let size_in = cfg.hidden_size;
@ -302,6 +301,7 @@ impl CausalSelfAttention {
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            cache: cache.clone(),
            use_flash_attn: cfg.use_flash_attn,
            span,
            span_rot,
@ -357,25 +357,19 @@ struct Block {
}

impl Block {
    fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "block");
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let rms_2 = RmsNorm::load(
@ -402,11 +396,11 @@ pub struct Llama {
}

impl Llama {
    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx, cache)?;
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let x = x.i((.., seq_len - 1, ..))?;
@ -414,12 +408,12 @@ impl Llama {
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg).unwrap())
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
            .collect();

        Ok(Self {
@ -2,6 +2,7 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::linear_no_bias as linear;
use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Debug, Clone)]
pub struct Config {
@ -69,11 +70,12 @@ impl Config {
    }
}

#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct Cache {
    masks: HashMap<usize, Tensor>,
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    pub kvs: Vec<Option<(Tensor, Tensor)>>,
    #[allow(clippy::type_complexity)]
    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    pub cos: Tensor,
    pub sin: Tensor,
    device: Device,
@ -103,24 +105,25 @@ impl Cache {
        let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
        let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
        Ok(Self {
            masks: HashMap::new(),
            masks: Arc::new(Mutex::new(HashMap::new())),
            use_kv_cache,
            kvs: vec![None; cfg.n_layers],
            kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
            cos,
            sin,
            device: vb.device().clone(),
        })
    }

    pub fn mask(&mut self, t: usize) -> Result<Tensor> {
        if let Some(mask) = self.masks.get(&t) {
    pub fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
            self.masks.insert(t, mask.clone());
            masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
@ -130,7 +133,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

#[derive(Debug, Clone)]
struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
@ -139,13 +141,14 @@ struct CausalSelfAttention {
    n_head: usize,
    n_key_value_head: usize,
    head_dim: usize,
    cache: Cache,
}

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
        let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
        let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
        let cos = cos.unsqueeze(1)?;
        let sin = sin.unsqueeze(1)?;
        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -159,13 +162,7 @@ impl CausalSelfAttention {
        Ok(rope)
    }

    fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
@ -175,15 +172,16 @@ impl CausalSelfAttention {
        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if cache.use_kv_cache {
            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
            }
            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
            cache[block_idx] = Some((k.clone(), v.clone()))
        }

        let k = self.repeat_kv(k)?;
@ -194,7 +192,7 @@ impl CausalSelfAttention {
        let v = v.transpose(1, 2)?.contiguous()?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
@ -218,7 +216,7 @@ impl CausalSelfAttention {
        }
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let size_in = cfg.dim;
        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -234,6 +232,7 @@ impl CausalSelfAttention {
            n_head: cfg.n_heads,
            n_key_value_head: cfg.n_kv_heads,
            head_dim: cfg.dim / cfg.n_heads,
            cache: cache.clone(),
        })
    }
}
@ -245,7 +244,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
    Ok(m)
}

#[derive(Debug, Clone)]
struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
@ -276,7 +274,6 @@ impl Mlp {
    }
}

#[derive(Debug, Clone)]
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
@ -294,23 +291,17 @@ impl Block {
        }
    }

    fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
@ -324,7 +315,6 @@ impl Block {
    }
}

#[derive(Debug, Clone)]
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
@ -334,23 +324,23 @@ pub struct Llama {
}

impl Llama {
    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, _seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx, cache)?;
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
        let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
            .collect();
        Ok(Self {
            wte,
@ -32,9 +32,9 @@ impl Config {
}

pub struct State {
    pub hs: Vec<Tensor>,
    pub prev_xs: Vec<[Tensor; D_CONV]>,
    pub pos: usize,
    hs: Vec<Tensor>,
    prev_xs: Vec<[Tensor; D_CONV]>,
    pos: usize,
}

impl State {
@ -1,968 +0,0 @@
use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};

// Equivalent to torch.repeat_interleave
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let img = img.unsqueeze(dim + 1)?;
    let mut dims = img.dims().to_vec();
    dims[dim + 1] = repeats;
    img.broadcast_as(dims)?.flatten(dim, dim + 1)
}
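
A quick sanity check of the unsqueeze/broadcast/flatten trick above (hypothetical demo, values made up):

```rust
fn demo() -> candle::Result<()> {
    use candle::{Device, Tensor};
    let xs = Tensor::new(&[1f32, 2.], &Device::Cpu)?; // shape (2,)
    // (2,) -> (2, 1) -> broadcast to (2, 2) -> flatten to (4,): [1, 1, 2, 2],
    // matching torch.repeat_interleave(xs, 2, dim=0).
    let ys = repeat_interleave(&xs, 2, 0)?;
    println!("{ys}");
    Ok(())
}
```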
pub mod speaker_encoder {
    use super::*;

    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        pub sampling_rate: usize,
        pub partial_n_frames: usize,
        pub model_hidden_size: usize,
        pub model_embedding_size: usize,
        pub model_num_layers: usize,
        pub mel_window_length: usize,
        pub mel_window_step: usize,
        pub mel_n_channels: usize,
    }

    impl Config {
        pub fn cfg() -> Self {
            Self {
                sampling_rate: 16_000,
                partial_n_frames: 160,
                model_hidden_size: 256,
                model_embedding_size: 256,
                model_num_layers: 3,
                mel_window_length: 25,
                mel_window_step: 10,
                mel_n_channels: 40,
            }
        }
    }

    pub struct Model {
        lstms: Vec<candle_nn::LSTM>,
        linear: Linear,
        cfg: Config,
    }

    type Slice = (usize, usize);

    impl Model {
        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
            let mut lstms = Vec::with_capacity(cfg.model_num_layers);
            let vb_l = vb.pp("lstm");
            for layer_idx in 0..cfg.model_num_layers {
                let c = candle_nn::LSTMConfig {
                    layer_idx,
                    ..Default::default()
                };
                let in_c = if layer_idx == 0 {
                    cfg.mel_n_channels
                } else {
                    cfg.model_hidden_size
                };
                let lstm = candle_nn::lstm(in_c, cfg.model_hidden_size, c, vb_l.clone())?;
                lstms.push(lstm)
            }
            let linear = linear_b(
                cfg.model_hidden_size,
                cfg.model_embedding_size,
                true,
                vb.pp("linear"),
            )?;
            Ok(Self { lstms, linear, cfg })
        }

        fn compute_partial_slices(
            &self,
            n_samples: usize,
            rate: f64,
            min_coverage: f64,
        ) -> (Vec<Slice>, Vec<Slice>) {
            let c = &self.cfg;
            // Compute how many frames separate two partial utterances
            let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
            let n_frames = n_samples / samples_per_frame + 1;
            let frame_step =
                (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
            let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
            // Compute the slices.
            let mut wav_slices = vec![];
            let mut mel_slices = vec![];
            for i in (0..steps).step_by(frame_step) {
                let mel_range = (i, i + c.partial_n_frames);
                let wav_range = (
                    i * samples_per_frame,
                    (i + c.partial_n_frames) * samples_per_frame,
                );
                mel_slices.push(mel_range);
                wav_slices.push(wav_range);
            }
            // Evaluate whether extra padding is warranted or not.
            let last_wav_range = match wav_slices.last() {
                None => return (wav_slices, mel_slices),
                Some(l) => *l,
            };
            let coverage = (n_samples - last_wav_range.0) as f64
                / (last_wav_range.1 - last_wav_range.0) as f64;
            if coverage > min_coverage && mel_slices.len() > 1 {
                mel_slices.pop();
                wav_slices.pop();
            }
            (wav_slices, mel_slices)
        }

        pub fn embed_utterance(
            &self,
            wav: &[f32],
            mel_filters: &[f32],
            rate: f64,
            min_c: f64,
            device: &Device,
        ) -> Result<Tensor> {
            let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
            let max_wave_length = match wav_slices.last() {
                Some(v) => v.1,
                None => candle::bail!("empty wav slices"),
            };
            let wav = if max_wave_length > wav.len() {
                let mut wav = wav.to_vec();
                // Pad the waveform with zeros up to max_wave_length. (The original
                // called resize with `max_wave_length - wav.len()`, which would
                // truncate the buffer instead of padding it.)
                wav.resize(max_wave_length, 0.0);
|
||||
std::borrow::Cow::Owned(wav)
|
||||
} else {
|
||||
std::borrow::Cow::Borrowed(wav)
|
||||
};
|
||||
let mel = crate::models::whisper::audio::log_mel_spectrogram_(
|
||||
wav.as_ref(),
|
||||
mel_filters,
|
||||
/* fft_size */ self.cfg.mel_window_length,
|
||||
/* fft_step */ self.cfg.mel_window_step,
|
||||
self.cfg.mel_n_channels,
|
||||
false,
|
||||
);
|
||||
let mels = mel_slices
|
||||
.iter()
|
||||
.flat_map(|s| [mel[s.0], mel[s.1]])
|
||||
.collect::<Vec<_>>();
|
||||
let mels = Tensor::from_vec(mels, (1, mel_slices.len(), 2), device)?;
|
||||
let partial_embeds = self.forward(&mels)?;
|
||||
let raw_embed = partial_embeds.mean(0)?;
|
||||
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
|
||||
raw_embed.broadcast_div(&norm)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Model {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
|
||||
// This is different from the Python transformers version as candle LSTM is batch first.
|
||||
let xs = xs.t()?;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.lstms.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
let xs = xs.t()?;
|
||||
let embeds_raw = xs.apply(&self.linear)?.relu()?;
|
||||
let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
|
||||
embeds_raw.broadcast_div(&norm)
|
||||
}
|
||||
}
|
||||
}
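As a sanity check on the slicing arithmetic above, here is `compute_partial_slices` worked out by hand for one second of 16 kHz audio with the default config and a hypothetical `rate` of 1.3 (illustrative numbers, not taken from the diff):

// Defaults: sampling_rate = 16_000, mel_window_step = 10, partial_n_frames = 160.
let samples_per_frame = 16_000 * 10 / 1000; // 160
let n_frames = 16_000 / samples_per_frame + 1; // 101
let frame_step = (16_000f64 / 1.3 / 160.0).round() as usize; // 77
let steps = (n_frames + frame_step).saturating_sub(160) + 1; // 19
// (0..19).step_by(77) only yields i = 0, i.e. a single partial utterance:
// mel slice (0, 160) in frames, wav slice (0, 160 * 160) = (0, 25_600) in samples,
// so embed_utterance zero-pads the 16_000-sample input up to 25_600 samples.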
type Rank = u32;

pub mod tokenizers {
    use super::*;
    use std::collections::HashMap;

    pub struct BPE {
        pub re: fancy_regex::Regex,
        pub end_of_text: usize,
        pub offset: usize,
        pub ranks: HashMap<Vec<u8>, Rank>,
    }

    impl BPE {
        pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
            let json = match json.as_object() {
                None => candle::bail!("json value is not an object"),
                Some(json) => json,
            };
            let re = match json.get("pat_str") {
                None => candle::bail!("json object has no pat_str field"),
                Some(pat_str) => match pat_str.as_str() {
                    None => candle::bail!("pat_str field is not a string"),
                    Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
                },
            };
            let offset = match json.get("offset") {
                None => candle::bail!("json object has no offset field"),
                Some(offset) => match offset.as_u64() {
                    None => candle::bail!("offset field is not a positive int"),
                    Some(offset) => offset as usize,
                },
            };
            let mut ranks = HashMap::new();
            for id in 0u8..=255 {
                ranks.insert(vec![id], id as u32);
            }
            let mergeable_ranks = match json.get("mergeable_ranks") {
                None => candle::bail!("json object has no mergeable_ranks field"),
                Some(mr) => match mr.as_object() {
                    None => candle::bail!("mergeable_ranks is not an object"),
                    Some(mr) => mr,
                },
            };
            for (key, value) in mergeable_ranks.iter() {
                let value = match value.as_u64() {
                    None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
                    Some(value) => value as u32,
                };
                if value < 256 {
                    continue;
                }
                // No escaping for other keys.
                let key = key.as_bytes().to_vec();
                ranks.insert(key, value);
            }
            Ok(Self {
                re,
                end_of_text,
                offset,
                ranks,
            })
        }

        // Taken from:
        // https://github.com/openai/tiktoken/blob/1b9faf2779855124f05174adf1383e53689ed94b/src/lib.rs#L16C1-L82C2
        fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
            // This is a vector of (start, rank).
            // The rank is of the pair starting at position start.
            let mut parts = Vec::with_capacity(piece.len() + 1);

            // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as
            // we train BPE the way we currently do, this is equivalent. An easy way to break
            // this would be to decouple merge priority from token index or to prevent specific
            // token merges.
            let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
            for i in 0..piece.len() - 1 {
                let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
                if rank < min_rank.0 {
                    min_rank = (rank, i);
                }
                parts.push((i, rank));
            }
            parts.push((piece.len() - 1, Rank::MAX));
            parts.push((piece.len(), Rank::MAX));

            let get_rank = {
                #[inline(always)]
                |parts: &Vec<(usize, Rank)>, i: usize| {
                    if (i + 3) < parts.len() {
                        // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet
                        // deleted parts[i + 1], see comment in the main loop.
                        *self
                            .ranks
                            .get(&piece[parts[i].0..parts[i + 3].0])
                            .unwrap_or(&Rank::MAX)
                    } else {
                        Rank::MAX
                    }
                }
            };

            // If you have n parts and m merges, this does O(mn) work.
            // We could do something with a heap and do O(m log n) work.
            // n is often very small so considerations like cache-locality outweigh the
            // algorithmic complexity downsides of the `parts` vector.
            while min_rank.0 != Rank::MAX {
                let i = min_rank.1;
                // Update parts[i] and parts[i - 1] before removing parts[i + 1], since
                // `parts.remove(i + 1)` will thrash the cache.
                if i > 0 {
                    parts[i - 1].1 = get_rank(&parts, i - 1);
                }
                parts[i].1 = get_rank(&parts, i);
                parts.remove(i + 1);

                min_rank = (Rank::MAX, usize::MAX);
                for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
                    if rank < min_rank.0 {
                        min_rank = (rank, i);
                    }
                }
            }
            parts
        }

        pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
            if piece.is_empty() {
                return Vec::new();
            }
            if piece.len() == 1 {
                return vec![self.ranks[piece]];
            }
            assert!(piece.len() > 1);
            self._byte_pair_merge(piece)
                .windows(2)
                .map(|part| self.ranks[&piece[part[0].0..part[1].0]])
                .collect()
        }

        pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
            let mut bpe_tokens: Vec<u32> = Vec::new();
            for word in self.re.find_iter(text) {
                let word = word.map_err(E::wrap)?;
                let word_tokens = self.byte_pair_encode(word.as_str().as_bytes());
                for &token in word_tokens.iter() {
                    bpe_tokens.push(token + self.offset as u32)
                }
            }
            bpe_tokens.push((self.end_of_text + self.offset) as u32);
            Ok(bpe_tokens)
        }
    }
}
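A minimal usage sketch for the tokenizer above, inside a function returning `candle::Result<()>`; the file name and the `end_of_text` id are placeholders rather than values from this diff:

let data = std::fs::read_to_string("tokenizer.json").map_err(candle::Error::wrap)?;
let json: serde_json::Value = serde_json::from_str(&data).map_err(candle::Error::wrap)?;
let bpe = tokenizers::BPE::from_json(&json, /* end_of_text (placeholder) */ 1537)?;
let ids = bpe.encode("Hello world")?;
// Every id comes back shifted by `offset`, and the sequence always ends with
// (end_of_text + offset) as u32.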
pub mod gpt {
    use super::*;

    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NormType {
        LayerNorm,
        RMSNorm,
    }

    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum AttnKernelType {
        Fa2,
        TorchAttn,
        Hand,
    }

    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NonLinearityType {
        Gelu,
        Swiglu,
    }

    enum Norm {
        RMSNorm(candle_nn::RmsNorm),
        LayerNorm(candle_nn::LayerNorm),
    }

    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L27
    #[derive(Debug, Clone)]
    pub struct Config {
        pub block_size: usize,
        pub vocab_sizes: Vec<usize>,
        pub target_vocab_sizes: Vec<usize>,
        pub n_layer: usize,
        pub n_head: usize,
        pub n_embd: usize,
        pub bias: bool,
        pub causal: bool,
        pub spk_emb_on_text: bool,
        pub norm_type: NormType,
        pub rmsnorm_eps: f64,
        pub nonlinearity_type: NonLinearityType,
        pub swiglu_multiple_of: Option<usize>,
        pub attn_kernel_type: AttnKernelType,
        pub kv_cache_enabled: bool,
    }

    impl Config {
        pub fn cfg1b_v0_1() -> Self {
            Self {
                n_layer: 6,
                n_head: 6,
                n_embd: 384,
                block_size: 1024,
                bias: false,
                vocab_sizes: vec![1538, 1025],
                causal: false,
                target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
                swiglu_multiple_of: Some(256),
                norm_type: NormType::LayerNorm,
                kv_cache_enabled: false,
                attn_kernel_type: AttnKernelType::TorchAttn,
                spk_emb_on_text: true,
                nonlinearity_type: NonLinearityType::Gelu,
                rmsnorm_eps: 1e-5,
            }
        }
    }

    impl Norm {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            match cfg.norm_type {
                NormType::RMSNorm => {
                    let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?;
                    Ok(Self::RMSNorm(rms_norm))
                }
                NormType::LayerNorm => {
                    let ln_cfg = candle_nn::LayerNormConfig {
                        affine: cfg.bias,
                        ..Default::default()
                    };
                    let layer_norm = candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?;
                    Ok(Self::LayerNorm(layer_norm))
                }
            }
        }
    }

    impl Module for Norm {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            match self {
                Self::RMSNorm(m) => m.forward(xs),
                Self::LayerNorm(m) => m.forward(xs),
            }
        }
    }

    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/attn.py#L18
    struct SelfAttention {
        c_attn: Linear,
        c_proj: Linear,
        n_head: usize,
    }

    impl SelfAttention {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            // The different attention variants are likely to be identical but still we only
            // accept TorchAttn for now.
            if cfg.attn_kernel_type != AttnKernelType::TorchAttn {
                candle::bail!("only TorchAttn is supported")
            }
            if cfg.kv_cache_enabled {
                candle::bail!("kv_cache_enabled=true is not supported")
            }
            let c_attn = linear_b(cfg.n_embd, cfg.n_embd * 3, cfg.bias, vb.pp("c_attn"))?;
            let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
            Ok(Self {
                c_attn,
                c_proj,
                n_head: cfg.n_head,
            })
        }
    }

    impl Module for SelfAttention {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let (b, t, c) = xs.dims3()?;
            let c_x = xs
                .apply(&self.c_attn)?
                .reshape((b, t, 3, self.n_head, c / self.n_head))?;
            let q = c_x.i((.., .., 0))?;
            let k = c_x.i((.., .., 1))?;
            let v = c_x.i((.., .., 2))?;
            let q = q.transpose(1, 2)?.contiguous()?;
            let k = k.transpose(1, 2)?.contiguous()?;
            let v = v.transpose(1, 2)?.contiguous()?;
            let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
            // TODO: causal mask
            let att = candle_nn::ops::softmax_last_dim(&att)?;
            let att = att.matmul(&v)?.transpose(1, 2)?;
            att.reshape((b, t, c))?.apply(&self.c_proj)
        }
    }

    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/layers.py#L43
    #[allow(clippy::upper_case_acronyms)]
    enum MLP {
        Gelu {
            c_fc: Linear,
            c_proj: Linear,
        },
        Swiglu {
            w1: Linear,
            w3: Linear,
            c_proj: Linear,
        },
    }

    impl MLP {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let hidden_dim = 4 * cfg.n_embd;
            let slf = match cfg.nonlinearity_type {
                NonLinearityType::Gelu => {
                    let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
                    Self::Gelu { c_fc, c_proj }
                }
                NonLinearityType::Swiglu => {
                    let hidden_dim = (2 * hidden_dim) / 3;
                    let swiglu_multiple_of = match cfg.swiglu_multiple_of {
                        None => candle::bail!("swiglu-multiple-of has to be set"),
                        Some(smo) => smo,
                    };
                    // Round the hidden dimension up to the next multiple of swiglu_multiple_of.
                    let hidden_dim = swiglu_multiple_of * (hidden_dim + swiglu_multiple_of - 1)
                        / swiglu_multiple_of;
                    let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
                    let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
                    Self::Swiglu { w1, w3, c_proj }
                }
            };
            Ok(slf)
        }
    }

    impl Module for MLP {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            match self {
                Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
                Self::Swiglu { w1, w3, c_proj } => {
                    let w1 = xs.apply(w1)?;
                    let w3 = xs.apply(w3)?;
                    (w1.silu()? * w3)?.apply(c_proj)
                }
            }
        }
    }

    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/layers/combined.py#L7
    struct Block {
        ln_1: Norm,
        ln_2: Norm,
        attn: SelfAttention,
        mlp: MLP,
    }

    impl Block {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let ln_1 = Norm::new(cfg, vb.pp("ln_1"))?;
            let ln_2 = Norm::new(cfg, vb.pp("ln_2"))?;
            let attn = SelfAttention::new(cfg, vb.pp("attn"))?;
            let mlp = MLP::new(cfg, vb.pp("mlp"))?;
            Ok(Block {
                ln_1,
                ln_2,
                attn,
                mlp,
            })
        }
    }

    impl Module for Block {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
            let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
            Ok(xs)
        }
    }

    // https://github.com/metavoiceio/metavoice-src/blob/11550bb4e8a1ad032cc1556cc924f7a4e767cbfa/fam/llm/model.py#L79
    #[allow(clippy::upper_case_acronyms)]
    pub struct Model {
        wtes: Vec<candle_nn::Embedding>,
        wpe: candle_nn::Embedding,
        h: Vec<Block>,
        ln_f: Norm,
        lm_heads: Vec<Linear>,
        cfg: Config,
        dtype: DType,
    }

    impl Model {
        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
            let vb_t = vb.pp("transformer");
            let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;
            let mut wtes = Vec::with_capacity(cfg.vocab_sizes.len());
            let vb_w = vb_t.pp("wtes");
            for (idx, vocab_size) in cfg.vocab_sizes.iter().enumerate() {
                let wte = candle_nn::embedding(*vocab_size, cfg.n_embd, vb_w.pp(idx))?;
                wtes.push(wte)
            }
            let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;

            let mut h = Vec::with_capacity(cfg.n_layer);
            let vb_h = vb_t.pp("h");
            for idx in 0..cfg.n_layer {
                let block = Block::new(&cfg, vb_h.pp(idx))?;
                h.push(block)
            }

            let mut lm_heads = Vec::with_capacity(cfg.target_vocab_sizes.len());
            let vb_l = vb.pp("lm_heads");
            for (idx, vocab_size) in cfg.target_vocab_sizes.iter().enumerate() {
                let head = linear_b(cfg.n_embd, *vocab_size, false, vb_l.pp(idx))?;
                lm_heads.push(head)
            }
            Ok(Self {
                wtes,
                wpe,
                h,
                ln_f,
                lm_heads,
                cfg,
                dtype: vb.dtype(),
            })
        }

        pub fn config(&self) -> &Config {
            &self.cfg
        }

        pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
            let device = idx.device();
            let (b, _num_hierarchies, t) = idx.dims3()?;
            let pos = Tensor::arange(0u32, t as u32, device)?;
            let pos_emb = pos.apply(&self.wpe)?;
            let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
            for (wte_idx, wte) in self.wtes.iter().enumerate() {
                let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
                tok_emb = (tok_emb + emb)?;
            }
            // TODO: speaker embs.
            let spk_emb = 0f64;
            let mut xs = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
            for block in self.h.iter() {
                xs = xs.apply(block)?
            }
            let xs = xs.apply(&self.ln_f)?;
            let mut logits = Vec::with_capacity(self.lm_heads.len());
            for lm_head in self.lm_heads.iter() {
                // non-causal mode only.
                let ys = xs.apply(lm_head)?;
                logits.push(ys)
            }
            Ok(logits)
        }
    }
}
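A quick smoke test of the first-stage GPT above (a sketch inside a function returning `candle::Result<()>`, assuming zero-initialized weights via `VarBuilder::zeros` rather than a real checkpoint):

let dev = candle::Device::Cpu;
let vb = candle_nn::VarBuilder::zeros(DType::F32, &dev);
let model = gpt::Model::new(gpt::Config::cfg1b_v0_1(), vb)?;
// (batch, num_hierarchies, seq): one token stream per entry in vocab_sizes.
let idx = Tensor::zeros((1, 2, 8), DType::U32, &dev)?;
let logits = model.forward(&idx)?;
assert_eq!(logits.len(), 6); // one logits tensor per lm_head / target vocab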
pub mod transformer {
    use super::*;

    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        pub block_size: usize,
        pub vocab_size: usize,
        pub n_layer: usize,
        pub n_head: usize,
        pub dim: usize,
        pub speaker_emb_dim: usize,
        pub intermediate_size: Option<usize>,
        pub n_local_heads: Option<usize>,
        pub norm_eps: f64,
    }

    impl Config {
        pub fn cfg1b_v0_1() -> Self {
            Self {
                n_layer: 24,
                n_head: 16,
                dim: 2048,
                vocab_size: 2562,
                speaker_emb_dim: 256,
                block_size: 2048,
                intermediate_size: None,
                n_local_heads: None,
                norm_eps: 1e-5,
            }
        }

        fn n_local_heads(&self) -> usize {
            self.n_local_heads.unwrap_or(self.n_head)
        }

        fn head_dim(&self) -> usize {
            self.dim / self.n_head
        }

        fn intermediate_size(&self) -> usize {
            match self.intermediate_size {
                Some(intermediate_size) => intermediate_size,
                None => {
                    let hidden_dim = self.dim * 4;
                    let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
                    (n_hidden + 255) / 256 * 256
                }
            }
        }
    }

    #[derive(Debug, Clone)]
    struct FeedForward {
        w1: Linear,
        w2: Linear,
        w3: Linear,
    }

    impl FeedForward {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let i_size = cfg.intermediate_size();
            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
            Ok(Self { w1, w2, w3 })
        }
    }

    impl Module for FeedForward {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
            swiglu.apply(&self.w2)
        }
    }

    #[derive(Debug, Clone)]
    struct Attention {
        wqkv: Linear,
        wo: Linear,
        dim: usize,
        kv_size: usize,
        n_local_heads: usize,
        head_dim: usize,
        n_head: usize,
        kv_cache: Option<(Tensor, Tensor)>,
    }

    impl Attention {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let n_local_heads = cfg.n_local_heads();
            let head_dim = cfg.head_dim();
            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
            Ok(Self {
                wqkv,
                wo,
                dim: cfg.dim,
                kv_size: n_local_heads * head_dim,
                n_local_heads,
                head_dim,
                n_head: cfg.n_head,
                kv_cache: None,
            })
        }

        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
            let (b_sz, seqlen, _) = xs.dims3()?;

            let qkv = xs.apply(&self.wqkv)?;
            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
            let q = q
                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
                .transpose(1, 2)?
                .contiguous()?;
            let k = k
                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
                .transpose(1, 2)?;

            let (k, v) = match &self.kv_cache {
                None => (k, v),
                Some((prev_k, prev_v)) => {
                    let k = Tensor::cat(&[prev_k, &k], 2)?;
                    let v = Tensor::cat(&[prev_v, &v], 2)?;
                    (k, v)
                }
            };
            self.kv_cache = Some((k.clone(), v.clone()));

            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;

            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

            let attn_weights = attn_weights.broadcast_add(mask)?;
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            let attn_output = attn_weights.matmul(&v)?;
            attn_output
                .transpose(1, 2)?
                .reshape((b_sz, seqlen, self.dim))?
                .apply(&self.wo)
        }

        fn clear_kv_cache(&mut self) {
            self.kv_cache = None
        }
    }

    #[derive(Debug, Clone)]
    struct Block {
        attention: Attention,
        feed_forward: FeedForward,
        ffn_norm: RmsNorm,
        attention_norm: RmsNorm,
    }

    impl Block {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let attention = Attention::new(cfg, vb.pp("attention"))?;
            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
            let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
            let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
            Ok(Self {
                attention,
                feed_forward,
                ffn_norm,
                attention_norm,
            })
        }

        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
            let hs = xs.apply(&self.attention_norm)?;
            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
        }

        fn clear_kv_cache(&mut self) {
            self.attention.clear_kv_cache()
        }
    }

    #[derive(Debug, Clone)]
    pub struct Model {
        tok_embeddings: Embedding,
        pos_embeddings: Embedding,
        speaker_cond_pos: Linear,
        layers: Vec<Block>,
        norm: RmsNorm,
        output: Linear,
        spk_cond_mask: Tensor,
    }

    impl Model {
        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
            let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
            let speaker_cond_pos = linear_b(
                cfg.speaker_emb_dim,
                cfg.dim,
                false,
                vb.pp("speaker_cond_pos"),
            )?;
            let mut layers = Vec::with_capacity(cfg.n_layer);
            let vb_l = vb.pp("layers");
            for layer_idx in 0..cfg.n_layer {
                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
                layers.push(layer)
            }
            let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
            let dtype = vb.dtype();
            let spk_cond_mask = Tensor::cat(
                &[
                    Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
                    Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
                ],
                0,
            )?;
            Ok(Self {
                tok_embeddings,
                pos_embeddings,
                speaker_cond_pos,
                layers,
                norm,
                output,
                spk_cond_mask,
            })
        }

        pub fn clear_kv_cache(&mut self) {
            for layer in self.layers.iter_mut() {
                layer.clear_kv_cache()
            }
        }

        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
            let (_b_sz, seqlen) = xs.dims2()?;
            let mask: Vec<_> = (0..seqlen)
                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
                .collect();
            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
            let mut xs = tok_embeddings
                .broadcast_add(&pos_embeddings)?
                .broadcast_add(
                    &spk_emb
                        .apply(&self.speaker_cond_pos)?
                        .broadcast_mul(&self.spk_cond_mask)?,
                )?;
            let mask = mask.to_dtype(xs.dtype())?;
            for layer in self.layers.iter_mut() {
                xs = layer.forward(&xs, pos, &mask)?
            }
            xs.narrow(1, seqlen - 1, 1)?
                .apply(&self.norm)?
                .apply(&self.output)
        }
    }
}
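And a sketch of incremental decoding with the second stage's kv-cache, again with zeroed weights and a dummy speaker embedding (inside a function returning `candle::Result<()>`); note that `spk_cond_mask` broadcasts the batch to two rows, a speaker-conditioned one and an unconditioned one:

let dev = candle::Device::Cpu;
let cfg = transformer::Config::cfg1b_v0_1();
let mut model = transformer::Model::new(&cfg, candle_nn::VarBuilder::zeros(DType::F32, &dev))?;
let spk_emb = Tensor::zeros((1, 1, cfg.speaker_emb_dim), DType::F32, &dev)?;
// The prompt pass at pos = 0 primes the cache; later calls feed one new token
// plus the running position, so attention only computes the new row.
let prompt = Tensor::zeros((1, 4), DType::U32, &dev)?;
let _logits = model.forward(&prompt, &spk_emb, 0)?;
let next = Tensor::zeros((1, 1), DType::U32, &dev)?;
let _logits = model.forward(&next, &spk_emb, 4)?;
model.clear_kv_cache();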
pub mod adapters {
    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
    pub struct TiltedEncodec {
        end_of_audio_token: u32,
    }

    impl TiltedEncodec {
        pub fn new(end_of_audio_token: u32) -> Self {
            Self { end_of_audio_token }
        }

        pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
            let mut text_ids = vec![];
            let mut extracted_audio_ids = vec![];
            let mut min_audio_ids_len = usize::MAX;
            for (book_id, tokens) in tokens.iter().enumerate() {
                let mut audio_ids = vec![];
                for &t in tokens.iter() {
                    #[allow(clippy::comparison_chain)]
                    if t > self.end_of_audio_token {
                        if book_id == 0 {
                            text_ids.push(t)
                        }
                    } else if t < self.end_of_audio_token {
                        audio_ids.push(t)
                    }
                }
                min_audio_ids_len = usize::min(min_audio_ids_len, audio_ids.len());
                extracted_audio_ids.push(audio_ids)
            }
            for audio_ids in extracted_audio_ids.iter_mut() {
                audio_ids.truncate(min_audio_ids_len)
            }
            (text_ids, extracted_audio_ids)
        }
    }

    // https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
    pub struct FlattenedInterleavedEncodec2Codebook {
        end_of_audio_token: u32,
    }

    impl FlattenedInterleavedEncodec2Codebook {
        pub fn new(end_of_audio_token: u32) -> Self {
            Self { end_of_audio_token }
        }

        pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
            let mut text_ids = vec![];
            let mut audio_ids1 = vec![];
            let mut audio_ids2 = vec![];
            for &t in tokens.iter() {
                #[allow(clippy::comparison_chain)]
                if t < self.end_of_audio_token {
                    audio_ids1.push(t)
                } else if t < 2 * self.end_of_audio_token {
                    audio_ids2.push(t - self.end_of_audio_token)
                } else {
                    text_ids.push(t)
                }
            }
            (text_ids, audio_ids1, audio_ids2)
        }
    }
}
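A worked example of the flattened two-codebook layout, with a made-up `end_of_audio_token` of 1024: ids below 1024 belong to the first codebook, ids in [1024, 2048) to the second (shifted back down), and anything else is text:

let adapter = adapters::FlattenedInterleavedEncodec2Codebook::new(1024);
let (text, audio1, audio2) = adapter.decode(&[3, 1030, 7, 2051]);
assert_eq!(text, vec![2051]);
assert_eq!(audio1, vec![3, 7]);
assert_eq!(audio2, vec![6]); // 1030 - 1024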
@ -8,17 +8,13 @@ pub mod convnext;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
pub mod efficientvit;
pub mod encodec;
pub mod falcon;
pub mod gemma;
pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod mamba;
pub mod marian;
pub mod metavoice;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;
@ -33,24 +29,20 @@ pub mod quantized_llama2_c;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod starcoder2;
pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod vocos;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;

@ -157,16 +157,16 @@ struct LayerWeights
    head_dim: usize,
    cos: Tensor,
    sin: Tensor,
    neg_inf: Tensor,
    kv_cache: Option<(Tensor, Tensor)>,
    span_attn: tracing::Span,
    span_rot: tracing::Span,
    span_mlp: tracing::Span,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
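For context on the reworked helper (a sketch, not part of the diff):

// With a mask m that is 1 on positions to hide, masked_fill(&att, &m, f32::NEG_INFINITY)
// keeps the unmasked attention scores and sends the masked ones to -inf, so the
// following softmax assigns them zero weight.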
@ -240,7 +240,7 @@ impl LayerWeights

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = mask.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, &self.neg_inf)?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax_last_dim(&att)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
        let y = att.matmul(&v.contiguous()?)?;
@ -298,7 +298,6 @@ impl ModelWeights
    pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
@ -338,7 +337,6 @@ impl ModelWeights
                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                cos: cos.clone(),
                sin: sin.clone(),
                neg_inf: neg_inf.clone(),
                kv_cache: None,
                span_attn,
                span_rot,
@ -387,7 +385,6 @@ impl ModelWeights
            .and_then(|m| m.to_f32())
            .unwrap_or(10000f32);
        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;

        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
@ -458,7 +455,6 @@ impl ModelWeights
                head_dim: embedding_length / head_count,
                cos: cos.clone(),
                sin: sin.clone(),
                neg_inf: neg_inf.clone(),
                kv_cache: None,
                span_attn,
                span_rot,

@ -7,7 +7,6 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
    xs / (xs.neg()?.exp()? + 1.0)?
}

#[derive(Debug, Clone)]
struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
@ -16,13 +15,14 @@ struct CausalSelfAttention {
    n_head: usize,
    n_key_value_head: usize,
    head_dim: usize,
    cache: Cache,
}

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len, h, n_embd) = x.dims4()?;
        let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
        let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
        let cos = cos.unsqueeze(1)?;
        let sin = sin.unsqueeze(1)?;
        let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@ -36,13 +36,7 @@ impl CausalSelfAttention {
        Ok(rope)
    }

    fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let (b_sz, seq_len, n_embd) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
@ -52,15 +46,16 @@ impl CausalSelfAttention {
        let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
        let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;

        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if cache.use_kv_cache {
            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
            }
            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
            cache[block_idx] = Some((k.clone(), v.clone()))
        }

        let k = self.repeat_kv(k)?;
@ -71,7 +66,7 @@ impl CausalSelfAttention {
        let v = v.transpose(1, 2)?.contiguous()?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
        let att = candle_nn::ops::softmax(&att, D::Minus1)?;
        // Convert to contiguous as matmul doesn't support strided vs for now.
@ -95,7 +90,7 @@ impl CausalSelfAttention {
        }
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let size_in = cfg.dim;
        let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
        let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@ -111,6 +106,7 @@ impl CausalSelfAttention {
            n_head: cfg.n_heads,
            n_key_value_head: cfg.n_kv_heads,
            head_dim: cfg.dim / cfg.n_heads,
            cache: cache.clone(),
        })
    }
}
@ -122,7 +118,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
    Ok(m)
}

#[derive(Debug, Clone)]
struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
@ -153,7 +148,6 @@ impl Mlp {
    }
}

#[derive(Debug, Clone)]
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
@ -171,23 +165,17 @@ impl Block {
    }
}

    fn forward(
        &self,
        x: &Tensor,
        index_pos: usize,
        block_idx: usize,
        cache: &mut Cache,
    ) -> Result<Tensor> {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm =
@ -201,7 +189,6 @@ impl Block {
    }
}

#[derive(Debug, Clone)]
pub struct QLlama {
    wte: Embedding,
    blocks: Vec<Block>,
@ -211,23 +198,23 @@ pub struct QLlama {
}

impl QLlama {
    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, _seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx, cache)?;
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
        let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
        let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.n_layers)
            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
            .collect();
        Ok(Self {
            wte,

@ -1,286 +0,0 @@
use crate::{
    quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
    quantized_var_builder::VarBuilder,
};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{GroupNorm, LayerNorm, Module};

pub use crate::models::rwkv_v5::{Config, State, Tokenizer};

#[derive(Debug, Clone)]
struct SelfAttention {
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    output: Linear,
    ln_x: candle_nn::GroupNorm,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    time_decay: Tensor,
    time_faaaa: Tensor,
    time_mix_gate: Tensor,
    layer_id: usize,
    n_attn_heads: usize,
}

impl SelfAttention {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;

        let vb_x = vb.pp("ln_x");
        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;

        let ln_x = GroupNorm::new(
            ln_x_weight,
            ln_x_bias,
            hidden_size,
            hidden_size / cfg.head_size,
            1e-5,
        )?;

        let time_mix_key = vb
            .get((1, 1, cfg.hidden_size), "time_mix_key")?
            .dequantize(vb.device())?;
        let time_mix_value = vb
            .get((1, 1, cfg.hidden_size), "time_mix_value")?
            .dequantize(vb.device())?;
        let time_mix_receptance = vb
            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
            .dequantize(vb.device())?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb
            .get((n_attn_heads, cfg.head_size), "time_decay")?
            .dequantize(vb.device())?;
        let time_faaaa = vb
            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
            .dequantize(vb.device())?;
        let time_mix_gate = vb
            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
            .dequantize(vb.device())?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            layer_id,
            n_attn_heads,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.time_decay.dim(0)?;
        let (b, t, s) = xs.dims3()?;
        let s = s / h;
        let (receptance, key, value, gate) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };
            let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
            let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
            let receptance = ((xs * &self.time_mix_receptance)?
                + &shifted * (1.0 - &self.time_mix_receptance)?)?;
            let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;

            let key = self.key.forward(&key)?;
            let value = self.value.forward(&value)?;
            let receptance = self.receptance.forward(&receptance)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate)
        };
        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        let time_decay = self
            .time_decay
            .exp()?
            .neg()?
            .exp()?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;
        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            state_ = (&at + time_decay.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
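The per-token loop above is the rwkv-v5 linear-attention recurrence; written out informally, per head:

// out_t  = r_t · (time_faaaa ⊙ (k_t v_t) + state)
// state <- k_t v_t + exp(-exp(time_decay)) ⊙ state
// Each step costs O(head_size^2) and carries a fixed-size state, instead of
// attending over the whole past sequence.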
#[derive(Debug, Clone)]
struct FeedForward {
    time_mix_key: Tensor,
    time_mix_receptance: Tensor,
    key: Linear,
    receptance: Linear,
    value: Linear,
    layer_id: usize,
}

impl FeedForward {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let int_size = cfg
            .intermediate_size
            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
        let time_mix_key = vb
            .get((1, 1, cfg.hidden_size), "time_mix_key")?
            .dequantize(vb.device())?;
        let time_mix_receptance = vb
            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
            .dequantize(vb.device())?;
        Ok(Self {
            key,
            receptance,
            value,
            time_mix_key,
            time_mix_receptance,
            layer_id,
        })
    }

    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let shifted = &state.per_layer[self.layer_id].feed_forward;
        let key = (xs.broadcast_mul(&self.time_mix_key)?
            + shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
        let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
            + shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
        let key = key.apply(&self.key)?.relu()?.sqr()?;
        let value = key.apply(&self.value)?;
        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
        let xs = (receptance * value)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct Block {
    pre_ln: Option<LayerNorm>,
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
        let pre_ln = if layer_id == 0 {
            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
            Some(ln)
        } else {
            None
        };
        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention,
            feed_forward,
        })
    }

    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let xs = match self.pre_ln.as_ref() {
            None => xs.clone(),
            Some(pre_ln) => xs.apply(pre_ln)?,
        };
        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
        let xs = (xs + attention)?;
        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
        let xs = (xs + feed_forward)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    rescale_every: usize,
    layers_are_rescaled: bool,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_b = vb_m.pp("blocks");
        for block_index in 0..cfg.num_hidden_layers {
            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
            blocks.push(block)
        }
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            layers_are_rescaled: false, // This seems to only happen for the f16/bf16 dtypes.
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let (_b_size, _seq_len) = xs.dims2()?;
        let mut xs = xs.apply(&self.embeddings)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            xs = block.forward(&xs, state)?;
            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
                xs = (xs / 2.)?
            }
        }
        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(xs)
    }
}

@ -1,332 +0,0 @@
use crate::{
    quantized_nn::{layer_norm, linear_no_bias as linear, Embedding, Linear},
    quantized_var_builder::VarBuilder,
};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{GroupNorm, LayerNorm, Module};

pub use crate::models::rwkv_v5::{Config, State, Tokenizer};

#[derive(Debug, Clone)]
struct SelfAttention {
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    output: Linear,
    ln_x: candle_nn::GroupNorm,
    time_mix_x: Tensor,
    time_mix_w: Tensor,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    time_decay: Tensor,
    time_faaaa: Tensor,
    time_mix_gate: Tensor,
    time_decay_w1: Tensor,
    time_decay_w2: Tensor,
    time_mix_w1: Tensor,
    time_mix_w2: Tensor,
    layer_id: usize,
    n_attn_heads: usize,
}

impl SelfAttention {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;

        let vb_x = vb.pp("ln_x");
        let ln_x_weight = vb_x.get(hidden_size, "weight")?.dequantize(vb.device())?;
        let ln_x_bias = vb_x.get(hidden_size, "bias")?.dequantize(vb.device())?;

        let ln_x = GroupNorm::new(
            ln_x_weight,
            ln_x_bias,
            hidden_size,
            hidden_size / cfg.head_size,
            1e-5,
        )?;

        let time_mix_x = vb
            .get((1, 1, cfg.hidden_size), "time_mix_x")?
            .dequantize(vb.device())?;
        let time_mix_w = vb
            .get((1, 1, cfg.hidden_size), "time_mix_w")?
            .dequantize(vb.device())?;
        let time_mix_key = vb
            .get((1, 1, cfg.hidden_size), "time_mix_key")?
            .dequantize(vb.device())?;
        let time_mix_value = vb
            .get((1, 1, cfg.hidden_size), "time_mix_value")?
            .dequantize(vb.device())?;
        let time_mix_receptance = vb
            .get((1, 1, cfg.hidden_size), "time_mix_receptance")?
            .dequantize(vb.device())?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb
            .get((1, 1, cfg.hidden_size), "time_decay")?
            .dequantize(vb.device())?;
        let time_faaaa = vb
            .get((n_attn_heads, cfg.head_size), "time_faaaa")?
            .dequantize(vb.device())?;
        let time_mix_gate = vb
            .get((1, 1, cfg.hidden_size), "time_mix_gate")?
            .dequantize(vb.device())?;
        let time_decay_w1 = vb
            .get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?
            .dequantize(vb.device())?;
        let time_decay_w2 = vb
            .get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?
            .dequantize(vb.device())?;
        let time_mix_w1 = vb
            .get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?
            .dequantize(vb.device())?;
        let time_mix_w2 = vb
            .get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?
            .dequantize(vb.device())?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_x,
            time_mix_w,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            time_decay_w1,
            time_decay_w2,
            time_mix_w1,
            time_mix_w2,
            layer_id,
            n_attn_heads,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.n_attn_heads;
        let (b, t, s) = xs.dims3()?;
        let s = s / h;
        let (receptance, key, value, gate, w) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };

            let sx = (&shifted - xs)?;
            let xxx = (xs + &sx * &self.time_mix_x)?;
            let xxx = xxx
                .broadcast_matmul(&self.time_mix_w1)?
                .tanh()?
                .reshape((b * t, 5, ()))?
                .transpose(0, 1)?;

            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;

            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);

            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;

            let w = (&self.time_decay
                + xw.broadcast_matmul(&self.time_decay_w1)?
                    .tanh()?
                    .broadcast_matmul(&self.time_decay_w2)?)?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;

            let key = self.key.forward(&xk)?;
            let value = self.value.forward(&xv)?;
            let receptance = self.receptance.forward(&xr)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate, w)
        };

        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        let w = w.exp()?.neg()?.exp()?;

        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            state_ = (&at + w.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
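The v6 variant follows the same recurrence, but the decay is data-dependent: it is recomputed per token from `xw` through the `time_decay_w1`/`time_decay_w2` bottleneck, and the five mixing coefficients come from the `time_mix_w1`/`time_mix_w2` projection. Informally:

// out_t  = r_t · (time_faaaa ⊙ (k_t v_t) + state)
// state <- k_t v_t + w_t ⊙ state,  with w_t = exp(-exp(time_decay + f(xw_t)))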
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct FeedForward {
|
||||
time_mix_key: Tensor,
|
||||
time_mix_receptance: Tensor,
|
||||
key: Linear,
|
||||
receptance: Linear,
|
||||
value: Linear,
|
||||
layer_id: usize,
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let int_size = cfg
|
||||
.intermediate_size
|
||||
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
|
||||
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
|
||||
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
|
||||
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
|
||||
let time_mix_key = vb
|
||||
.get((1, 1, cfg.hidden_size), "time_mix_key")?
|
||||
.dequantize(vb.device())?;
|
||||
let time_mix_receptance = vb
|
||||
.get((1, 1, cfg.hidden_size), "time_mix_receptance")?
|
||||
.dequantize(vb.device())?;
|
||||
Ok(Self {
|
||||
key,
|
||||
receptance,
|
||||
value,
|
||||
time_mix_key,
|
||||
time_mix_receptance,
|
||||
layer_id,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
|
||||
let shifted = state.per_layer[self.layer_id]
|
||||
.feed_forward
|
||||
.broadcast_sub(xs)?;
|
||||
let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
|
||||
let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
|
||||
let key = key.apply(&self.key)?.relu()?.sqr()?;
|
||||
let value = key.apply(&self.value)?;
|
||||
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
|
||||
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
|
||||
let xs = (receptance * value)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}

#[derive(Debug, Clone)]
struct Block {
    pre_ln: Option<LayerNorm>,
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
        let pre_ln = if layer_id == 0 {
            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
            Some(ln)
        } else {
            None
        };
        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention,
            feed_forward,
        })
    }

    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let xs = match self.pre_ln.as_ref() {
            None => xs.clone(),
            Some(pre_ln) => xs.apply(pre_ln)?,
        };
        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
        let xs = (xs + attention)?;
        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
        let xs = (xs + feed_forward)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    rescale_every: usize,
    layers_are_rescaled: bool,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_b = vb_m.pp("blocks");
        for block_index in 0..cfg.num_hidden_layers {
            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
            blocks.push(block)
        }
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            layers_are_rescaled: false, // This seems to only be needed for the f16/bf16 dtypes.
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let (_b_size, _seq_len) = xs.dims2()?;
        let mut xs = xs.apply(&self.embeddings)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            xs = block.forward(&xs, state)?;
            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
                xs = (xs / 2.)?
            }
        }
        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(xs)
    }
}
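
The halving in forward pairs with weights that are pre-divided when layers_are_rescaled is set, keeping f16 activations in range. A toy illustration of why periodic halving matters, with a made-up per-block growth factor:

// Toy numeric illustration of the rescaling above (growth factor invented).
fn main() {
    let (num_layers, rescale_every) = (24usize, 6usize);
    let mut magnitude = 1.0f32;
    for layer in 0..num_layers {
        magnitude *= 1.5; // pretend each block grows activations a bit
        if (layer + 1) % rescale_every == 0 {
            magnitude /= 2.0; // mirrors `xs = (xs / 2.)?`
        }
    }
    println!("final magnitude: {magnitude}"); // far smaller than 1.5^24
}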

@ -186,14 +186,10 @@ impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            vb.pp("input_layernorm"),
        )?;
        let input_layernorm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {

@ -244,7 +240,7 @@ impl Model {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
        let norm = layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,

@ -22,15 +22,15 @@ pub struct Config {
    pub rescale_every: usize,
}

pub struct StatePerLayer {
    pub extract_key_value: Tensor,
    pub linear_attention: Tensor,
    pub feed_forward: Tensor,
struct StatePerLayer {
    extract_key_value: Tensor,
    linear_attention: Tensor,
    feed_forward: Tensor,
}

pub struct State {
    pub per_layer: Vec<StatePerLayer>,
    pub pos: usize,
    per_layer: Vec<StatePerLayer>,
    pos: usize,
}

impl State {

@ -124,7 +124,7 @@ impl SelfAttention {
        let (b, t, s) = xs.dims3()?;
        let s = s / h;
        let (receptance, key, value, gate) = {
            // extract key-value
            // exctract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?

@ -164,9 +164,10 @@ impl SelfAttention {

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            //
            let rt = receptance.i((.., .., t_..t_ + 1))?;
            let kt = key.i((.., .., .., t_..t_ + 1))?;
            let vt = value.i((.., .., t_..t_ + 1))?;
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;

@ -1,295 +0,0 @@
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{IndexOp, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};

pub use crate::models::rwkv_v5::{Config, State, Tokenizer};

#[derive(Debug, Clone)]
struct SelfAttention {
    key: Linear,
    receptance: Linear,
    value: Linear,
    gate: Linear,
    output: Linear,
    ln_x: candle_nn::GroupNorm,
    time_mix_x: Tensor,
    time_mix_w: Tensor,
    time_mix_key: Tensor,
    time_mix_value: Tensor,
    time_mix_receptance: Tensor,
    time_decay: Tensor,
    time_faaaa: Tensor,
    time_mix_gate: Tensor,
    time_decay_w1: Tensor,
    time_decay_w2: Tensor,
    time_mix_w1: Tensor,
    time_mix_w2: Tensor,
    layer_id: usize,
    n_attn_heads: usize,
}

impl SelfAttention {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let attn_hidden_size = cfg.attention_hidden_size;
        let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
        let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
        let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
        let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
        let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
        let ln_x = candle_nn::group_norm(
            hidden_size / cfg.head_size,
            hidden_size,
            1e-5,
            vb.pp("ln_x"),
        )?;

        let time_mix_x = vb.get((1, 1, cfg.hidden_size), "time_mix_x")?;
        let time_mix_w = vb.get((1, 1, cfg.hidden_size), "time_mix_w")?;
        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
        let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
        let n_attn_heads = cfg.hidden_size / cfg.head_size;
        let time_decay = vb.get((1, 1, cfg.hidden_size), "time_decay")?;
        let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
        let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
        let time_decay_w1 = vb.get((cfg.hidden_size, n_attn_heads * 2), "time_decay_w1")?;
        let time_decay_w2 = vb.get((n_attn_heads * 2, cfg.hidden_size), "time_decay_w2")?;
        let time_mix_w1 = vb.get((cfg.hidden_size, n_attn_heads * 5), "time_mix_w1")?;
        let time_mix_w2 = vb.get((5, n_attn_heads, cfg.hidden_size), "time_mix_w2")?;
        Ok(Self {
            key,
            value,
            receptance,
            gate,
            output,
            ln_x,
            time_mix_x,
            time_mix_w,
            time_mix_key,
            time_mix_value,
            time_mix_receptance,
            time_decay,
            time_faaaa,
            time_mix_gate,
            time_decay_w1,
            time_decay_w2,
            time_mix_w1,
            time_mix_w2,
            layer_id,
            n_attn_heads,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let h = self.n_attn_heads;
        let (b, t, s) = xs.dims3()?;
        let s = s / h;
        let (receptance, key, value, gate, w) = {
            // extract key-value
            let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
            let shifted = if shifted.rank() == 2 {
                shifted.unsqueeze(1)?
            } else {
                shifted
            };

            let sx = (&shifted - xs)?;
            let xxx = (xs + &sx * &self.time_mix_x)?;
            let xxx = xxx
                .broadcast_matmul(&self.time_mix_w1)?
                .tanh()?
                .reshape((b * t, 5, ()))?
                .transpose(0, 1)?;

            let xxx = xxx.matmul(&self.time_mix_w2)?.reshape((5, b, t, ()))?;

            let (mw, mk, mv, mr, mg) = (xxx.i(0)?, xxx.i(1)?, xxx.i(2)?, xxx.i(3)?, xxx.i(4)?);

            let xw = (xs + &sx * (&self.time_mix_w + &mw)?)?;
            let xk = (xs + &sx * (&self.time_mix_key + &mk)?)?;
            let xv = (xs + &sx * (&self.time_mix_value + &mv)?)?;
            let xr = (xs + &sx * (&self.time_mix_receptance + &mr)?)?;
            let xg = (xs + &sx * (&self.time_mix_gate + &mg)?)?;

            let w = (&self.time_decay
                + xw.broadcast_matmul(&self.time_decay_w1)?
                    .tanh()?
                    .broadcast_matmul(&self.time_decay_w2)?)?
            .reshape(((), 1, 1))?
            .reshape((self.n_attn_heads, (), 1))?;

            let key = self.key.forward(&xk)?;
            let value = self.value.forward(&xv)?;
            let receptance = self.receptance.forward(&xr)?;
            let gate = candle_nn::ops::silu(&self.gate.forward(&xg)?)?;
            state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
            (receptance, key, value, gate, w)
        };

        // linear attention
        let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
        let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
        let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
        let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;

        let w = w.exp()?.neg()?.exp()?;

        let time_faaaa =
            self.time_faaaa
                .reshape(((), 1, 1))?
                .reshape((self.n_attn_heads, (), 1))?;

        let mut out: Vec<Tensor> = Vec::with_capacity(t);
        for t_ in 0..t {
            let rt = receptance.i((.., .., t_..t_ + 1))?.contiguous()?;
            let kt = key.i((.., .., .., t_..t_ + 1))?.contiguous()?;
            let vt = value.i((.., .., t_..t_ + 1))?.contiguous()?;
            let at = kt.matmul(&vt)?;
            let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
            let out_ = rt.matmul(&rhs)?.squeeze(2)?;
            state_ = (&at + w.broadcast_mul(&state_))?;
            out.push(out_)
        }
        let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
        let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
        let out = (out * gate)?.apply(&self.output)?;
        state.per_layer[self.layer_id].linear_attention = state_;
        Ok(out)
    }
}
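
The five mixing coefficients above (mw, mk, mv, mr, mg) come out of a low-rank bottleneck: project through time_mix_w1, tanh, then back out through time_mix_w2, one branch per target. A dimension-only sketch of that data flow, using the shapes declared in new() (the concrete sizes are invented for illustration):

// Shape walk-through of the v6 low-rank token mix above (no tensor library).
fn main() {
    let (b, t, hidden, heads) = (1usize, 3usize, 64usize, 4usize);
    let after_w1 = (b, t, 5 * heads); // x time_mix_w1: (hidden, 5 * heads)
    let per_branch = (5, b * t, heads); // reshape((b * t, 5, ())) + transpose
    let mixed = (5, b, t, hidden); // batched matmul with time_mix_w2: (5, heads, hidden)
    println!("{after_w1:?} -> {per_branch:?} -> {mixed:?}");
}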

#[derive(Debug, Clone)]
struct FeedForward {
    time_mix_key: Tensor,
    time_mix_receptance: Tensor,
    key: Linear,
    receptance: Linear,
    value: Linear,
    layer_id: usize,
}

impl FeedForward {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let int_size = cfg
            .intermediate_size
            .unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
        let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
        let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
        let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
        let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
        let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
        Ok(Self {
            key,
            receptance,
            value,
            time_mix_key,
            time_mix_receptance,
            layer_id,
        })
    }

    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let shifted = state.per_layer[self.layer_id]
            .feed_forward
            .broadcast_sub(xs)?;
        let key = (xs + shifted.broadcast_mul(&self.time_mix_key)?)?;
        let receptance = (xs + shifted.broadcast_mul(&self.time_mix_receptance)?)?;
        let key = key.apply(&self.key)?.relu()?.sqr()?;
        let value = key.apply(&self.value)?;
        let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
        state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
        let xs = (receptance * value)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
struct Block {
    pre_ln: Option<LayerNorm>,
    ln1: LayerNorm,
    ln2: LayerNorm,
    attention: SelfAttention,
    feed_forward: FeedForward,
}

impl Block {
    fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
        let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
        let pre_ln = if layer_id == 0 {
            let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
            Some(ln)
        } else {
            None
        };
        let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
        let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
        Ok(Self {
            pre_ln,
            ln1,
            ln2,
            attention,
            feed_forward,
        })
    }

    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let xs = match self.pre_ln.as_ref() {
            None => xs.clone(),
            Some(pre_ln) => xs.apply(pre_ln)?,
        };
        let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
        let xs = (xs + attention)?;
        let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
        let xs = (xs + feed_forward)?;
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embeddings: Embedding,
    blocks: Vec<Block>,
    ln_out: LayerNorm,
    head: Linear,
    rescale_every: usize,
    layers_are_rescaled: bool,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("rwkv");
        let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
        let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_b = vb_m.pp("blocks");
        for block_index in 0..cfg.num_hidden_layers {
            let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
            blocks.push(block)
        }
        let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
        let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
        Ok(Self {
            embeddings,
            blocks,
            ln_out,
            head,
            rescale_every: cfg.rescale_every,
            layers_are_rescaled: false, // This seems to only be needed for the f16/bf16 dtypes.
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let (_b_size, _seq_len) = xs.dims2()?;
        let mut xs = xs.apply(&self.embeddings)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            xs = block.forward(&xs, state)?;
            if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
                xs = (xs / 2.)?
            }
        }
        let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
        state.pos += 1;
        Ok(xs)
    }
}

@ -1,705 +0,0 @@
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    #[serde(default)]
    pub id2label: HashMap<String, String>,
    pub num_channels: usize,
    pub num_encoder_blocks: usize,
    pub depths: Vec<usize>,
    pub sr_ratios: Vec<usize>,
    pub hidden_sizes: Vec<usize>,
    pub patch_sizes: Vec<usize>,
    pub strides: Vec<usize>,
    pub num_attention_heads: Vec<usize>,
    pub mlp_ratios: Vec<usize>,
    pub hidden_act: candle_nn::Activation,
    pub layer_norm_eps: f64,
    pub decoder_hidden_size: usize,
}

#[derive(Debug, Clone)]
struct SegformerOverlapPatchEmbeddings {
    projection: Conv2d,
    layer_norm: candle_nn::LayerNorm,
}

impl SegformerOverlapPatchEmbeddings {
    fn new(
        config: &Config,
        patch_size: usize,
        stride: usize,
        num_channels: usize,
        hidden_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let projection = conv2d(
            num_channels,
            hidden_size,
            patch_size,
            Conv2dConfig {
                stride,
                padding: patch_size / 2,
                ..Default::default()
            },
            vb.pp("proj"),
        )?;
        let layer_norm =
            candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
        Ok(Self {
            projection,
            layer_norm,
        })
    }
}

impl Module for SegformerOverlapPatchEmbeddings {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let embeddings = self.projection.forward(x)?;
        let shape = embeddings.shape();
        // [B, C, H, W] -> [B, H * W, C]
        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
        let embeddings = self.layer_norm.forward(&embeddings)?;
        // [B, H * W, C] -> [B, C, H, W]
        let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
        Ok(embeddings)
    }
}
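
The projection above is an overlapped patch embedding: the kernel is patch_size, the stride is smaller, and padding patch_size / 2 lets neighbouring patches share pixels. Standard conv arithmetic gives the output grid:

// Output spatial size of the overlapped patch embedding above.
fn conv_out(h: usize, patch: usize, stride: usize) -> usize {
    (h + 2 * (patch / 2) - patch) / stride + 1
}

fn main() {
    // First stage in the test config at the end of this file: patch 7, stride 4.
    println!("{}", conv_out(224, 7, 4)); // 56
}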

#[derive(Debug, Clone)]
struct SegformerEfficientSelfAttention {
    num_attention_heads: usize,
    attention_head_size: usize,
    query: Linear,
    key: Linear,
    value: Linear,
    sr: Option<Conv2d>,
    layer_norm: Option<layer_norm::LayerNorm>,
}

impl SegformerEfficientSelfAttention {
    fn new(
        config: &Config,
        hidden_size: usize,
        num_attention_heads: usize,
        sequence_reduction_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        if hidden_size % num_attention_heads != 0 {
            candle::bail!(
                "The hidden size {} is not a multiple of the number of attention heads {}",
                hidden_size,
                num_attention_heads
            )
        }
        let attention_head_size = hidden_size / num_attention_heads;
        let all_head_size = num_attention_heads * attention_head_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
            (
                Some(conv2d(
                    hidden_size,
                    hidden_size,
                    sequence_reduction_ratio,
                    Conv2dConfig {
                        stride: sequence_reduction_ratio,
                        ..Default::default()
                    },
                    vb.pp("sr"),
                )?),
                Some(candle_nn::layer_norm(
                    hidden_size,
                    config.layer_norm_eps,
                    vb.pp("layer_norm"),
                )?),
            )
        } else {
            (None, None)
        };
        Ok(Self {
            num_attention_heads,
            attention_head_size,
            query,
            key,
            value,
            sr,
            layer_norm,
        })
    }

    fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
        let (batch, seq_length, _) = hidden_states.shape().dims3()?;
        let new_shape = &[
            batch,
            seq_length,
            self.num_attention_heads,
            self.attention_head_size,
        ];
        let hidden_states = hidden_states.reshape(new_shape)?;
        let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
        Ok(hidden_states)
    }
}

impl Module for SegformerEfficientSelfAttention {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // [B, C, H, W] -> [B, H * W, C]
        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
        let query = self
            .transpose_for_scores(self.query.forward(&hidden_states)?)?
            .contiguous()?;
        let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
            let hidden_states = sr.forward(x)?;
            // [B, C, H, W] -> [B, H * W, C]
            let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
            layer_norm.forward(&hidden_states)?
        } else {
            // already [B, H * W, C]
            hidden_states
        };
        // standard self-attention
        let key = self
            .transpose_for_scores(self.key.forward(&hidden_states)?)?
            .contiguous()?;
        let value = self
            .transpose_for_scores(self.value.forward(&hidden_states)?)?
            .contiguous()?;
        let attention_scores =
            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
        let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        let result = attention_scores.matmul(&value)?;
        let result = result.permute((0, 2, 1, 3))?.contiguous()?;
        result.flatten_from(D::Minus2)
    }
}
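
The sr branch above is what makes this attention "efficient": keys and values are computed on a spatially reduced map, so the score matrix shrinks by the square of the reduction ratio. A quick count for the first stage of the test config (sr ratio 8, assumed 56x56 grid):

// Score-matrix sizes with and without the sequence reduction above.
fn main() {
    let (h, w, r) = (56usize, 56usize, 8usize);
    let q_tokens = h * w;
    let kv_tokens = (h / r) * (w / r);
    println!("full: {} entries, reduced: {}", q_tokens * q_tokens, q_tokens * kv_tokens);
}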

#[derive(Debug, Clone)]
struct SegformerSelfOutput {
    dense: Linear,
}

impl SegformerSelfOutput {
    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
        Ok(Self { dense })
    }
}

impl Module for SegformerSelfOutput {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.dense.forward(x)
    }
}

#[derive(Debug, Clone)]
struct SegformerAttention {
    attention: SegformerEfficientSelfAttention,
    output: SegformerSelfOutput,
}

impl SegformerAttention {
    fn new(
        config: &Config,
        hidden_size: usize,
        num_attention_heads: usize,
        sequence_reduction_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let attention = SegformerEfficientSelfAttention::new(
            config,
            hidden_size,
            num_attention_heads,
            sequence_reduction_ratio,
            vb.pp("self"),
        )?;
        let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
        Ok(Self { attention, output })
    }
}

impl Module for SegformerAttention {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let attention_output = self.attention.forward(x)?;
        self.output.forward(&attention_output)
    }
}

#[derive(Debug, Clone)]
struct SegformerDWConv {
    dw_conv: Conv2d,
}

impl SegformerDWConv {
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let dw_conv = conv2d(
            dim,
            dim,
            3,
            Conv2dConfig {
                stride: 1,
                padding: 1,
                groups: dim,
                ..Default::default()
            },
            vb.pp("dwconv"),
        )?;
        Ok(Self { dw_conv })
    }
}

impl Module for SegformerDWConv {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.dw_conv.forward(x)
    }
}

#[derive(Debug, Clone)]
struct SegformerMixFFN {
    dense1: Linear,
    dw_conv: SegformerDWConv,
    act: Activation,
    dense2: Linear,
}

impl SegformerMixFFN {
    fn new(
        config: &Config,
        in_features: usize,
        hidden_features: usize,
        out_features: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
        let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
        let act = config.hidden_act;
        let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
        Ok(Self {
            dense1,
            dw_conv,
            act,
            dense2,
        })
    }
}

impl Module for SegformerMixFFN {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (batch, _, height, width) = x.shape().dims4()?;
        let hidden_states = self
            .dense1
            .forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
        let channels = hidden_states.dim(2)?;
        let hidden_states = self.dw_conv.forward(
            &hidden_states
                .permute((0, 2, 1))?
                .reshape((batch, channels, height, width))?,
        )?;
        let hidden_states = self.act.forward(&hidden_states)?;
        let hidden_states = self
            .dense2
            .forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
        let channels = hidden_states.dim(2)?;
        hidden_states
            .permute((0, 2, 1))?
            .reshape((batch, channels, height, width))
    }
}

#[derive(Debug, Clone)]
struct SegformerLayer {
    layer_norm_1: candle_nn::LayerNorm,
    attention: SegformerAttention,
    layer_norm_2: candle_nn::LayerNorm,
    mlp: SegformerMixFFN,
}

impl SegformerLayer {
    fn new(
        config: &Config,
        hidden_size: usize,
        num_attention_heads: usize,
        sequence_reduction_ratio: usize,
        mlp_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
        let attention = SegformerAttention::new(
            config,
            hidden_size,
            num_attention_heads,
            sequence_reduction_ratio,
            vb.pp("attention"),
        )?;
        let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
        let mlp = SegformerMixFFN::new(
            config,
            hidden_size,
            hidden_size * mlp_ratio,
            hidden_size,
            vb.pp("mlp"),
        )?;
        Ok(Self {
            layer_norm_1,
            attention,
            layer_norm_2,
            mlp,
        })
    }
}

impl Module for SegformerLayer {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let shape = x.shape().dims4()?;
        // [B, C, H, W] -> [B, H * W, C]
        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
        let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
        let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
        // attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
        let attention_output = self.attention.forward(&layer_norm_output)?;
        let hidden_states = (attention_output + hidden_states)?;
        let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
        let mlp_output = self
            .mlp
            .forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
        hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
    }
}

#[derive(Debug, Clone)]
struct SegformerEncoder {
    /// config file
    config: Config,
    /// a list of embeddings
    patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
    /// a list of attention blocks, each consisting of layers
    blocks: Vec<Vec<SegformerLayer>>,
    /// a final list of layer norms
    layer_norms: Vec<candle_nn::LayerNorm>,
}

impl SegformerEncoder {
    fn new(config: Config, vb: VarBuilder) -> Result<Self> {
        let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
        let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
        let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
        for i in 0..config.num_encoder_blocks {
            let patch_size = config.patch_sizes[i];
            let stride = config.strides[i];
            let hidden_size = config.hidden_sizes[i];
            let num_channels = if i == 0 {
                config.num_channels
            } else {
                config.hidden_sizes[i - 1]
            };
            patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
                &config,
                patch_size,
                stride,
                num_channels,
                hidden_size,
                vb.pp(&format!("patch_embeddings.{}", i)),
            )?);
            let mut layers = Vec::with_capacity(config.depths[i]);
            for j in 0..config.depths[i] {
                let sequence_reduction_ratio = config.sr_ratios[i];
                let num_attention_heads = config.num_attention_heads[i];
                let mlp_ratio = config.mlp_ratios[i];
                layers.push(SegformerLayer::new(
                    &config,
                    hidden_size,
                    num_attention_heads,
                    sequence_reduction_ratio,
                    mlp_ratio,
                    vb.pp(&format!("block.{}.{}", i, j)),
                )?);
            }
            blocks.push(layers);
            layer_norms.push(layer_norm(
                hidden_size,
                config.layer_norm_eps,
                vb.pp(&format!("layer_norm.{}", i)),
            )?);
        }
        Ok(Self {
            config,
            patch_embeddings,
            blocks,
            layer_norms,
        })
    }
}

impl ModuleWithHiddenStates for SegformerEncoder {
    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
        let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
        let mut hidden_states = x.clone();
        for i in 0..self.config.num_encoder_blocks {
            hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
            for layer in &self.blocks[i] {
                hidden_states = layer.forward(&hidden_states)?;
            }
            let shape = hidden_states.shape().dims4()?;
            hidden_states =
                self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
            hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
            all_hidden_states.push(hidden_states.clone());
        }
        Ok(all_hidden_states)
    }
}

#[derive(Debug, Clone)]
struct SegformerModel {
    encoder: SegformerEncoder,
}

impl SegformerModel {
    fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
        Ok(Self { encoder })
    }
}

impl ModuleWithHiddenStates for SegformerModel {
    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
        self.encoder.forward(x)
    }
}

#[derive(Debug, Clone)]
struct SegformerMLP {
    proj: Linear,
}

impl SegformerMLP {
    fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
        let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
        Ok(Self { proj })
    }
}

impl Module for SegformerMLP {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.proj.forward(x)
    }
}

#[derive(Debug, Clone)]
struct SegformerDecodeHead {
    linear_c: Vec<SegformerMLP>,
    linear_fuse: candle_nn::Conv2d,
    batch_norm: candle_nn::BatchNorm,
    classifier: candle_nn::Conv2d,
}

impl SegformerDecodeHead {
    fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
        for i in 0..config.num_encoder_blocks {
            let hidden_size = config.hidden_sizes[i];
            linear_c.push(SegformerMLP::new(
                config,
                hidden_size,
                vb.pp(&format!("linear_c.{}", i)),
            )?);
        }
        let linear_fuse = conv2d_no_bias(
            config.decoder_hidden_size * config.num_encoder_blocks,
            config.decoder_hidden_size,
            1,
            Conv2dConfig::default(),
            vb.pp("linear_fuse"),
        )?;
        let batch_norm = candle_nn::batch_norm(
            config.decoder_hidden_size,
            config.layer_norm_eps,
            vb.pp("batch_norm"),
        )?;
        let classifier = conv2d_no_bias(
            config.decoder_hidden_size,
            num_labels,
            1,
            Conv2dConfig::default(),
            vb.pp("classifier"),
        )?;
        Ok(Self {
            linear_c,
            linear_fuse,
            batch_norm,
            classifier,
        })
    }

    fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
        if encoder_hidden_states.len() != self.linear_c.len() {
            candle::bail!(
                "The number of encoder hidden states {} is not equal to the number of linear layers {}",
                encoder_hidden_states.len(),
                self.linear_c.len()
            )
        }
        // the finest-resolution feature map comes first
        let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
        let mut hidden_states = Vec::with_capacity(self.linear_c.len());
        for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
            let (batch, _, height, width) = hidden_state.shape().dims4()?;
            let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
            let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
                batch,
                hidden_state.dim(2)?,
                height,
                width,
            ))?;
            let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
            hidden_states.push(hidden_state);
        }
        hidden_states.reverse();
        let hidden_states = Tensor::cat(&hidden_states, 1)?;
        let hidden_states = self.linear_fuse.forward(&hidden_states)?;
        let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
        let hidden_states = hidden_states.relu()?;
        self.classifier.forward(&hidden_states)
    }
}
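
The decode head projects every stage to decoder_hidden_size channels, upsamples everything to the finest grid, and concatenates before the 1x1 fuse convolution, which is where the decoder_hidden_size * num_encoder_blocks input width above comes from:

// Channel bookkeeping for linear_fuse above, with the test-config values.
fn main() {
    let (decoder_hidden_size, num_encoder_blocks) = (256usize, 4usize);
    let fused_in = decoder_hidden_size * num_encoder_blocks;
    println!("linear_fuse: {fused_in} -> {decoder_hidden_size} channels"); // 1024 -> 256
}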

trait ModuleWithHiddenStates {
    fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
}

#[derive(Debug, Clone)]
pub struct SemanticSegmentationModel {
    segformer: SegformerModel,
    decode_head: SegformerDecodeHead,
}

impl SemanticSegmentationModel {
    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
        let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
        Ok(Self {
            segformer,
            decode_head,
        })
    }
}

impl Module for SemanticSegmentationModel {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let hidden_states = self.segformer.forward(x)?;
        self.decode_head.forward(&hidden_states)
    }
}

#[derive(Debug, Clone)]
pub struct ImageClassificationModel {
    segformer: SegformerModel,
    classifier: Linear,
}

impl ImageClassificationModel {
    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
        let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
        Ok(Self {
            segformer,
            classifier,
        })
    }
}

impl Module for ImageClassificationModel {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let all_hidden_states = self.segformer.forward(x)?;
        let hidden_states = all_hidden_states.last().unwrap();
        let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
        let mean = hidden_states.mean(1)?;
        self.classifier.forward(&mean)
    }
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_config_json_load() {
        let raw_json = r#"{
            "architectures": ["SegformerForImageClassification"],
            "attention_probs_dropout_prob": 0.0,
            "classifier_dropout_prob": 0.1,
            "decoder_hidden_size": 256,
            "depths": [2, 2, 2, 2],
            "downsampling_rates": [1, 4, 8, 16],
            "drop_path_rate": 0.1,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_sizes": [32, 64, 160, 256],
            "image_size": 224,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-06,
            "mlp_ratios": [4, 4, 4, 4],
            "model_type": "segformer",
            "num_attention_heads": [1, 2, 5, 8],
            "num_channels": 3,
            "num_encoder_blocks": 4,
            "patch_sizes": [7, 3, 3, 3],
            "sr_ratios": [8, 4, 2, 1],
            "strides": [4, 2, 2, 2],
            "torch_dtype": "float32",
            "transformers_version": "4.12.0.dev0"
        }"#;
        let config: Config = serde_json::from_str(raw_json).unwrap();
        assert_eq!(vec![4, 2, 2, 2], config.strides);
        assert_eq!(1e-6, config.layer_norm_eps);
    }
}

@ -4,7 +4,7 @@ use candle_nn::{Activation, LayerNorm, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;

// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm.py
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub(crate) vocab_size: usize,

@ -14,10 +14,10 @@ pub struct Config {
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) partial_rotary_factor: f64,
    pub(crate) rope_pct: f64,
    pub(crate) rope_theta: f64,
    pub(crate) max_position_embeddings: usize,
    pub(crate) layer_norm_eps: f64,
    pub(crate) norm_eps: f64,
    pub(crate) use_cache: bool,
    #[serde(default)]
    pub(crate) use_qkv_bias: bool, // Used in StableLM-2

@ -35,10 +35,10 @@ impl Config {
            num_attention_heads: 32,
            num_key_value_heads: 32,
            hidden_act: Activation::Silu,
            partial_rotary_factor: 0.25,
            rope_pct: 0.25,
            rope_theta: 10_000.,
            max_position_embeddings: 4096,
            layer_norm_eps: 1e-5,
            norm_eps: 1e-5,
            use_qkv_bias: false,
            use_cache: true,
            use_flash_attn,

@ -50,7 +50,7 @@ impl Config {
    }

    pub fn rotary_ndims(&self) -> usize {
        (self.head_dim() as f64 * self.partial_rotary_factor) as usize
        (self.head_dim() as f64 * self.rope_pct) as usize
    }
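
Both names in the hunk above (partial_rotary_factor and rope_pct) denote the same thing: only a fraction of each head's dimensions receive rotary embeddings. A small sketch of the computation with the defaults shown above (the hidden size is an assumption for illustration):

// rotary_ndims with the defaults above; hidden_size 2560 is an assumption.
fn main() {
    let (hidden_size, num_attention_heads) = (2560usize, 32usize);
    let rope_pct = 0.25;
    let head_dim = hidden_size / num_attention_heads;
    let rotary_ndims = (head_dim as f64 * rope_pct) as usize;
    println!("{rotary_ndims} of {head_dim} dims are rotated"); // 20 of 80
}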

    pub fn num_kv_groups(&self) -> usize {

@ -316,14 +316,11 @@ impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm = candle_nn::layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            vb.pp("input_layernorm"),
        )?;
        let input_layernorm =
            candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = candle_nn::layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {

@ -375,7 +372,7 @@ impl Model {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb_m.pp("norm"))?;
        let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,

@ -1,347 +0,0 @@
#![allow(unused)]
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear_b, LayerNorm, Linear, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    intermediate_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    hidden_act: candle_nn::Activation,
    max_position_embeddings: usize,
    norm_epsilon: f64,
    rope_theta: f64,
    use_bias: bool,
    sliding_window: Option<usize>,
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}
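
rotate_half plus the cos/sin multiply above is the usual rotary-embedding identity: each pair of dimensions is rotated by a position-dependent angle. A plain-Rust sketch on a single pair:

// One rotary pair: x * cos(theta) + rotate_half(x) * sin(theta).
fn rope_pair(x0: f32, x1: f32, theta: f32) -> (f32, f32) {
    // rotate_half maps (x0, x1) to (-x1, x0), so this is a 2-D rotation.
    (
        x0 * theta.cos() - x1 * theta.sin(),
        x0 * theta.sin() + x1 * theta.cos(),
    )
}

fn main() {
    println!("{:?}", rope_pair(1.0, 0.0, std::f32::consts::FRAC_PI_2)); // ~(0, 1)
}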

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Linear,
    c_proj: Linear,
    act: candle_nn::Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = linear_b(h_size, i_size, cfg.use_bias, vb.pp("c_fc"))?;
        let c_proj = linear_b(i_size, h_size, cfg.use_bias, vb.pp("c_proj"))?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.c_fc)?.apply(&self.act)?.apply(&self.c_proj)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
        let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
        let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
        let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }
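
    // repeat_kv above implements grouped-query attention: every key/value head
    // is duplicated num_kv_groups times so the head counts line up with the
    // queries. A minimal standalone sketch of the duplication order produced by
    // the unsqueeze/expand/reshape:
    //
    // fn repeat_kv(kv: &[Vec<f32>], n_rep: usize) -> Vec<Vec<f32>> {
    //     kv.iter()
    //         .flat_map(|head| std::iter::repeat(head.clone()).take(n_rep))
    //         .collect()
    // }
    //
    // fn main() {
    //     let kv = vec![vec![1.0], vec![2.0]]; // 2 kv heads
    //     println!("{:?}", repeat_kv(&kv, 3)); // 6 heads for 6 query heads
    // }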

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;

        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

        let attn_weights = match attention_mask {
            None => attn_weights,
            Some(mask) => attn_weights.broadcast_add(mask)?,
        };
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&value_states)?;
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm =
            layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Linear,
    sliding_window: Option<usize>,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = layer_norm(cfg.hidden_size, cfg.norm_epsilon, vb_m.pp("norm"))?;
        let lm_head = candle_nn::Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let sliding_window = self.sliding_window.unwrap_or(tgt_len + 42);
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}
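
prepare_decoder_attention_mask above builds a causal mask with an optional sliding window: position i may attend to j only when j <= i and i - j <= window. The same rule on plain floats (0.0 keeps a score, -inf drops it):

// Sliding-window causal mask, same condition as in the method above.
fn causal_sliding_mask(t: usize, window: usize) -> Vec<Vec<f32>> {
    (0..t)
        .map(|i| {
            (0..t)
                .map(|j| {
                    if i < j || j + window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.0
                    }
                })
                .collect()
        })
        .collect()
}

fn main() {
    for row in causal_sliding_mask(5, 2) {
        println!("{row:?}");
    }
}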

156 candle-transformers/src/models/vocos.rs Normal file
@ -0,0 +1,156 @@
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};

pub struct AdaLayerNorm {
    eps: f64,
    dim: usize,
    scale: Embedding,
    shift: Embedding,
}

fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
    let x_dtype = x.dtype();
    let internal_dtype = match x_dtype {
        DType::F16 | DType::BF16 => DType::F32,
        d => d,
    };
    let hidden_size = x.dim(D::Minus1)?;
    let x = x.to_dtype(internal_dtype)?;
    let x = {
        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        x.broadcast_sub(&mean_x)?
    };
    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
    let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
    x_normed.to_dtype(x_dtype)
}

impl AdaLayerNorm {
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        eps: f64,
        vb: VarBuilder,
    ) -> Result<Self> {
        let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
        let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
        Ok(Self {
            eps,
            dim: embedding_dim,
            scale,
            shift,
        })
    }

    pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
        let scale = self.scale.forward(cond_embedding_id)?;
        let shift = self.shift.forward(cond_embedding_id)?;
        let xs = layer_norm(xs, self.eps)?;
        xs * scale + shift
    }
}
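
The hand-rolled layer_norm above centers on the last dimension and divides by the root of the mean squared deviation plus eps, with no learned affine; AdaLayerNorm then applies a scale and shift looked up per conditioning id. A numeric sketch of the normalization itself:

// Plain-slice version of the layer_norm helper above (eps inside the sqrt).
fn layer_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect()
}

fn main() {
    println!("{:?}", layer_norm(&[1.0, 2.0, 3.0, 4.0], 1e-6));
}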

pub struct ConvNeXtBlock {
    dwconv: Conv1d,
    pwconv1: Linear,
    pwconv2: Linear,
    gamma: Option<Tensor>,
}

impl ConvNeXtBlock {
    pub fn new(
        dim: usize,
        intermediate_dim: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let dwconv = {
            let cfg = Conv1dConfig {
                padding: 3,
                groups: dim,
                ..Default::default()
            };
            conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
        };
        let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
        let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
        let gamma = if layer_scale_init_value > 0. {
            Some(vb.get(dim, "gamma")?)
        } else {
            None
        };
        Ok(Self {
            dwconv,
            pwconv1,
            pwconv2,
            gamma,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
        // TODO: norm
        let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
        let xs = match self.gamma.as_ref() {
            Some(gamma) => (gamma * xs)?,
            None => xs,
        };
        xs.transpose(1, 2)? + residual
    }
}
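
The dwconv above is depthwise (groups == dim): each channel is filtered independently with its own 7-tap kernel, and padding 3 preserves the sequence length. A small standalone sketch with a 3-tap kernel and zero padding:

// Depthwise 1-D convolution: one kernel per channel, zero-padded ends.
fn dwconv1d(x: &[Vec<f32>], kernels: &[Vec<f32>]) -> Vec<Vec<f32>> {
    x.iter()
        .zip(kernels)
        .map(|(chan, ker)| {
            let pad = ker.len() / 2;
            (0..chan.len())
                .map(|i| {
                    ker.iter()
                        .enumerate()
                        .map(|(j, w)| {
                            let idx = i + j;
                            if idx < pad || idx - pad >= chan.len() {
                                0.0
                            } else {
                                w * chan[idx - pad]
                            }
                        })
                        .sum::<f32>()
                })
                .collect()
        })
        .collect()
}

fn main() {
    println!("{:?}", dwconv1d(&[vec![1.0, 2.0, 3.0]], &[vec![0.25, 0.5, 0.25]]));
}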

struct VocosBackbone {
    embed: Conv1d,
    convnext: Vec<ConvNeXtBlock>,
    final_layer_norm: candle_nn::LayerNorm,
}

impl VocosBackbone {
    pub fn new(
        input_channels: usize,
        dim: usize,
        intermediate_dim: usize,
        num_layers: usize,
        layer_scale_init_value: f64,
        adanorm_num_embeddings: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let embed = {
            let cfg = Conv1dConfig {
                padding: 3,
                ..Default::default()
            };
            conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
        };
        let mut convnext = Vec::with_capacity(num_layers);
        let vb_c = vb.pp("convnext");
        for i in 0..num_layers {
            let block = ConvNeXtBlock::new(
                dim,
                intermediate_dim,
                layer_scale_init_value,
                adanorm_num_embeddings,
                vb_c.pp(i),
            )?;
            convnext.push(block)
        }
        let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
        Ok(Self {
            embed,
            convnext,
            final_layer_norm,
        })
    }

    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.embed)?;
        // TODO: norm
        let mut xs = xs.transpose(1, 2)?;
        for conv_block in self.convnext.iter() {
            xs = conv_block.forward(&xs)?
        }
        xs.apply(&self.final_layer_norm)
    }
}

@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
    mel
}

pub fn log_mel_spectrogram_<T: Float>(
fn log_mel_spectrogram_<T: Float>(
    samples: &[T],
    filters: &[T],
    fft_size: usize,

@ -47,12 +47,6 @@ impl Linear {
    }
}

pub fn linear_b(d1: usize, d2: usize, b: bool, vb: VarBuilder) -> Result<Linear> {
    let inner = candle_nn::linear_b(d1, d2, b, vb)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { inner, span })
}

pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
    let inner = candle_nn::linear(d1, d2, vb)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");