Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 19:18:50 +00:00.

Compare commits: copy-multi...llama-v3-m (39 commits)

Commits (SHA1):
6d6d87f8b3
9c532aef47
f7a6468238
2b93dffe64
e6ee7ba4d4
1690ab45d2
8de0ce6cba
ce6d08df94
2817643db9
4d14777673
f135b7963d
af955f260c
8ad822a983
e198bb0816
f7d5bf5b97
c119600d6e
c449f65b12
db7dbf3071
4ecedb1598
53e5380bf6
50e49ecc5f
4c88c3ce06
8b8fb630df
fb805b8ca2
79e3bec789
e6d412b156
26cbbf8d84
2bf413caa3
3ad4770eb6
a0460cd2b1
b81ecf712d
a4d5a414e3
798e0335cd
718671a0d5
c5fe4a7f89
7f354473cf
33c9b66554
9fd52b3b71
e662431acf
@@ -63,8 +63,9 @@ We also provide a some command line based examples using state of the art models
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
   the SOLAR-10.7B variant.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
-  Deepmind.
+- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
+- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
+  Griffin based models from Google that mix attention with a RNN like state.
 - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets. Also supports
@@ -374,9 +375,9 @@ git submodule update --init
 /usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
 ```
 
-This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
+This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
 
 ```
-env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
+env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
 ```
 
 #### Linking error on windows when running rustdoc or mdbook tests
@@ -7,4 +7,5 @@ criterion_main!(
     benchmarks::random::benches,
     benchmarks::where_cond::benches,
     benchmarks::conv_transpose2d::benches,
+    benchmarks::qmatmul::benches,
 );
@@ -1,6 +1,7 @@
 pub(crate) mod affine;
 pub(crate) mod conv_transpose2d;
 pub(crate) mod matmul;
+pub(crate) mod qmatmul;
 pub(crate) mod random;
 pub(crate) mod where_cond;
candle-core/benches/benchmarks/qmatmul.rs (new file, 72 lines)
@@ -0,0 +1,72 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle_core::{
+    quantized::{self, GgmlDType, QMatMul},
+    Device, Module, Tensor,
+};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+
+fn run(matmul: &QMatMul, x: &Tensor) {
+    matmul.forward(&x).unwrap();
+}
+
+fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
+    let b = 1;
+    let m = 1;
+    let n = 1024;
+    let k = 1024;
+
+    let lhs = (0..(m * k))
+        .map(|v| v as f32 / (m * k) as f32)
+        .collect::<Vec<_>>();
+    let rhs = (0..(k * n))
+        .map(|v| v as f32 / (n * k) as f32)
+        .collect::<Vec<_>>();
+
+    let lhs = Tensor::from_slice(&lhs, (m, k), device).unwrap();
+    let rhs = Tensor::from_slice(&rhs, (k, n), device).unwrap();
+
+    let qtensor = quantized::QTensor::quantize(&rhs.t().unwrap(), dtype).unwrap();
+    let matmul = quantized::QMatMul::from_qtensor(qtensor).unwrap();
+
+    let flops = b * m * n * k;
+
+    let mut group = c.benchmark_group(device.bench_name(format!("qmatmul_{:?}", dtype)));
+    group.sample_size(200);
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&matmul), black_box(&lhs));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+    for device in handler.devices {
+        for dtype in vec![
+            GgmlDType::F32,
+            GgmlDType::F16,
+            GgmlDType::Q4_0,
+            GgmlDType::Q4_1,
+            GgmlDType::Q5_0,
+            GgmlDType::Q5_1,
+            GgmlDType::Q8_0,
+            GgmlDType::Q2K,
+            GgmlDType::Q3K,
+            GgmlDType::Q4K,
+            GgmlDType::Q5K,
+            GgmlDType::Q6K,
+        ] {
+            run_bench(c, &device, dtype);
+        }
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
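For orientation, a minimal standalone sketch of the `QMatMul` API this benchmark exercises (shapes mirror `run_bench` above; illustrative only, not part of the diff). The weight is quantized from the transposed right-hand side, so `forward` maps an `(m, k)` input to an `(m, n)` output:

```rust
use candle_core::quantized::{self, GgmlDType, QMatMul};
use candle_core::{Device, Module, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (m, k, n) = (1, 1024, 1024);
    let lhs = Tensor::rand(0f32, 1f32, (m, k), &device)?;
    let rhs = Tensor::rand(0f32, 1f32, (k, n), &device)?;
    // Quantize the transposed weight, as run_bench does above.
    let qtensor = quantized::QTensor::quantize(&rhs.t()?, GgmlDType::Q4_0)?;
    let matmul = QMatMul::from_qtensor(qtensor)?;
    let out = matmul.forward(&lhs)?; // (m, k) x (k, n) -> (m, n)
    assert_eq!(out.dims(), &[m, n]);
    Ok(())
}
```

Note that the bench passes the element count `b * m * n * k` to `Throughput::Bytes`, so criterion's bytes-per-second readout is really multiply-accumulates per second.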
@@ -142,4 +142,7 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage>;
 
     fn set_seed(&self, _: u64) -> Result<()>;
+
+    /// Synchronize should block until all the operations on the device are completed.
+    fn synchronize(&self) -> Result<()>;
 }
@@ -624,7 +624,7 @@ impl Tensor {
             Op::Unary(arg, UnaryOp::Silu) => {
                 let sum_grad = grads.or_insert(arg)?;
                 // d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
-                let sigmoid_arg = (*node / arg)?;
+                let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
                 let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
                 *sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
             }
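For reference, the identity in the comment follows from silu(x) = x·σ(x); the new line computes σ(x) directly as 1/(1 + e^(−x)) rather than as silu(x)/x (the old `*node / arg`), which is undefined at x = 0:

```latex
\frac{d}{dx}\,\mathrm{silu}(x)
  = \frac{d}{dx}\bigl(x\,\sigma(x)\bigr)
  = \sigma(x) + x\,\sigma'(x)
  = \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr),
\qquad \sigma'(x) = \sigma(x)\bigl(1 - \sigma(x)\bigr).
```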
@@ -2628,6 +2628,10 @@ impl BackendDevice for CpuDevice {
         };
         Ok(storage)
     }
+
+    fn synchronize(&self) -> Result<()> {
+        Ok(())
+    }
 }
 
 #[macro_export]
@@ -407,4 +407,9 @@ impl BackendDevice for CudaDevice {
             device: self.clone(),
         })
     }
+
+    fn synchronize(&self) -> Result<()> {
+        self.device.synchronize().map_err(crate::Error::wrap)?;
+        Ok(())
+    }
 }
@@ -337,4 +337,12 @@ impl Device {
             }
         }
     }
+
+    pub fn synchronize(&self) -> Result<()> {
+        match self {
+            Self::Cpu => Ok(()),
+            Self::Cuda(d) => d.synchronize(),
+            Self::Metal(d) => d.synchronize(),
+        }
+    }
 }
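A quick sketch of why `Device::synchronize` matters when timing kernels (illustrative helper, not from the diff): CUDA and Metal queue work asynchronously, so without a final sync only launch overhead is measured.

```rust
use candle_core::{Device, Result, Tensor};
use std::time::{Duration, Instant};

fn time_matmul(device: &Device) -> Result<Duration> {
    let a = Tensor::rand(0f32, 1f32, (1024, 1024), device)?;
    let b = Tensor::rand(0f32, 1f32, (1024, 1024), device)?;
    let start = Instant::now();
    let _c = a.matmul(&b)?;
    device.synchronize()?; // block until the queued kernels actually finish
    Ok(start.elapsed())
}
```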
@@ -229,4 +229,8 @@ impl crate::backend::BackendDevice for CudaDevice {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    fn synchronize(&self) -> Result<()> {
+        Ok(())
+    }
 }
@@ -241,4 +241,8 @@ impl crate::backend::BackendDevice for MetalDevice {
     fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
+
+    fn synchronize(&self) -> Result<()> {
+        Ok(())
+    }
 }
@@ -283,5 +283,5 @@ impl MetalDevice {
     }
 
 fn buf_size(size: NSUInteger) -> NSUInteger {
-    (size - 1).next_power_of_two() as NSUInteger
+    size.saturating_sub(1).next_power_of_two() as NSUInteger
 }
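The change guards the `size == 0` case: `NSUInteger` is unsigned, so `(size - 1)` underflows (a panic in debug builds). A standalone check of the new behaviour, with `u64` standing in for `NSUInteger` (illustrative, not from the diff):

```rust
fn buf_size(size: u64) -> u64 {
    // Power-of-two bucketing so Metal buffers can be reused across allocations.
    size.saturating_sub(1).next_power_of_two()
}

fn main() {
    assert_eq!(buf_size(0), 1); // the old `(0 - 1)` would underflow here
    assert_eq!(buf_size(1000), 1024);
    assert_eq!(buf_size(1024), 1024);
}
```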
@@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
-use candle_metal_kernels::CallConvTranspose2dCfg;
-use candle_metal_kernels::Kernels;
+use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
 use metal::{Buffer, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
@@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
 mod device;
 pub use device::{DeviceId, MetalDevice};
 
+fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
+    BufferOffset {
+        buffer,
+        offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
+    }
+}
 /// Simple way to catch lock error without
 /// depending on T
 #[derive(thiserror::Error, Debug)]
@@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage {
 
         let buffer = device.new_buffer(el, self.dtype, "affine")?;
         let command_buffer = self.device.command_buffer()?;
-        if layout.is_contiguous() && layout.start_offset() == 0 {
+        let src = buffer_o(&self.buffer, layout, dtype);
+        if layout.is_contiguous() {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
                 DType::F16 => "affine_f16",
@@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 name,
                 el,
-                &self.buffer,
+                src,
                 &buffer,
                 mul as f32,
                 add as f32,
@@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 name,
                 layout.dims(),
-                &self.buffer,
+                src,
                 layout.stride(),
-                layout.start_offset() * dtype.size_in_bytes(),
                 &buffer,
                 mul as f32,
                 add as f32,
@@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage {
 
         let buffer = device.new_buffer(el, self.dtype, "powf")?;
         let command_buffer = self.device.command_buffer()?;
-        if layout.is_contiguous() && layout.start_offset() == 0 {
+        let src = buffer_o(&self.buffer, layout, dtype);
+        if layout.is_contiguous() {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
                 DType::F16 => "powf_f16",
@@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 name,
                 el,
-                &self.buffer,
+                src,
                 &buffer,
                 pow as f32,
             )
@@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 name,
                 layout.dims(),
-                &self.buffer,
+                src,
                 layout.stride(),
-                layout.start_offset() * dtype.size_in_bytes(),
                 &buffer,
                 pow as f32,
             )
@@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage {
 
         let buffer = device.new_buffer(el, self.dtype, "elu")?;
         let command_buffer = self.device.command_buffer()?;
-        if layout.is_contiguous() && layout.start_offset() == 0 {
+        let src = buffer_o(&self.buffer, layout, self.dtype);
+        if layout.is_contiguous() {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
                 DType::F16 => "elu_f16",
@@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 name,
                 el,
-                &self.buffer,
+                src,
                 &buffer,
                 alpha as f32,
             )
@@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 name,
                 layout.dims(),
-                &self.buffer,
+                src,
                 layout.stride(),
-                layout.start_offset() * dtype.size_in_bytes(),
                 &buffer,
                 alpha as f32,
             )
@@ -309,6 +314,7 @@ impl BackendStorage for MetalStorage {
         let dtype = if return_index { DType::U32 } else { self.dtype };
         let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
         let command_buffer = self.device.command_buffer()?;
+        let src = buffer_o(&self.buffer, layout, self.dtype);
         candle_metal_kernels::call_reduce_strided(
             &device.device,
             &command_buffer,
@@ -317,8 +323,7 @@ impl BackendStorage for MetalStorage {
             &dims,
             &stride,
             dst_el,
-            &self.buffer,
-            layout.start_offset() * self.dtype.size_in_bytes(),
+            src,
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage {
         let el_count = shape.elem_count();
         let buffer = device.new_buffer(el_count, dtype, "todtype")?;
         let command_buffer = device.command_buffer()?;
-        if layout.is_contiguous() && layout.start_offset() == 0 {
+        let src = buffer_o(&self.buffer, layout, self.dtype);
+        if layout.is_contiguous() {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::BF16) => "cast_u32_bf16",
                 (DType::U32, DType::F16) => "cast_u32_f16",
@@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 kernel_name,
                 el_count,
-                &self.buffer,
-                layout.start_offset() * self.dtype.size_in_bytes(),
+                src,
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 kernel_name,
                 layout.dims(),
-                &self.buffer,
+                src,
                 layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage {
         let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
         let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
-        if layout.is_contiguous() && layout.start_offset() == 0 {
+        let src = buffer_o(&self.buffer, layout, self.dtype);
+        if layout.is_contiguous() {
             use candle_metal_kernels::unary::contiguous;
 
             let kernel_name = match (B::KERNEL, dtype) {
@@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage {
                 &device.kernels,
                 kernel_name,
                 el_count,
-                &self.buffer,
+                src,
                 &buffer,
             )
             .map_err(MetalError::from)?;
@@ -535,6 +540,7 @@ impl BackendStorage for MetalStorage {
                 ("urelu", DType::F32) => strided::relu::FLOAT,
                 ("uround", DType::F32) => strided::round::FLOAT,
+                ("utanh", DType::F32) => strided::tanh::FLOAT,
 
                 ("ucos", DType::F16) => strided::cos::HALF,
                 ("usin", DType::F16) => strided::sin::HALF,
                 ("usqr", DType::F16) => strided::sqr::HALF,
@@ -552,21 +558,39 @@ impl BackendStorage for MetalStorage {
                 ("urelu", DType::F16) => strided::relu::HALF,
                 ("uround", DType::F16) => strided::round::HALF,
                 ("utanh", DType::F16) => strided::tanh::HALF,
+
+                ("ucos", DType::BF16) => strided::cos::BFLOAT,
+                ("usin", DType::BF16) => strided::sin::BFLOAT,
+                ("usqr", DType::BF16) => strided::sqr::BFLOAT,
+                ("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
+                ("uneg", DType::BF16) => strided::neg::BFLOAT,
+                ("uexp", DType::BF16) => strided::exp::BFLOAT,
+                ("ulog", DType::BF16) => strided::log::BFLOAT,
+                ("ugelu", DType::BF16) => strided::gelu::BFLOAT,
+                ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
+                ("uerf", DType::BF16) => strided::erf::BFLOAT,
+                ("usilu", DType::BF16) => strided::silu::BFLOAT,
+                ("uabs", DType::BF16) => strided::abs::BFLOAT,
+                ("uceil", DType::BF16) => strided::ceil::BFLOAT,
+                ("ufloor", DType::BF16) => strided::floor::BFLOAT,
+                ("urelu", DType::BF16) => strided::relu::BFLOAT,
+                ("uround", DType::BF16) => strided::round::BFLOAT,
+                ("utanh", DType::BF16) => strided::tanh::BFLOAT,
 
                 (name, dtype) => {
                     crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
                 }
             };
+            let dst = BufferOffset::zero_offset(&buffer);
             candle_metal_kernels::call_unary_strided(
                 &device.device,
                 &command_buffer,
                 &device.kernels,
                 kernel_name,
                 layout.dims(),
-                &self.buffer,
+                src,
                 layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
-                &buffer,
-                0,
+                dst,
             )
             .map_err(MetalError::from)?;
         }
@@ -613,21 +637,21 @@ impl BackendStorage for MetalStorage {
             (DType::U8, DType::U8) => "where_u8_u8",
             (left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
         };
+        let src = buffer_o(&self.buffer, layout, self.dtype);
+        let t = buffer_o(&t.buffer, t_l, t.dtype);
+        let f = buffer_o(&f.buffer, f_l, f.dtype);
         candle_metal_kernels::call_where_cond_strided(
             &device.device,
             &command_buffer,
             &device.kernels,
             name,
             dims,
-            &self.buffer,
-            (
-                layout.stride(),
-                layout.start_offset() * self.dtype.size_in_bytes(),
-            ),
-            &t.buffer,
-            (t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
-            &f.buffer,
-            (f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
+            src,
+            layout.stride(),
+            t,
+            t_l.stride(),
+            f,
+            f_l.stride(),
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -660,6 +684,7 @@ impl BackendStorage for MetalStorage {
             DType::F32 => "im2col1d_f32",
             dtype => crate::bail!("Metal conv1d {dtype:?} not implemented"),
         };
+        let src = buffer_o(&self.buffer, layout, self.dtype);
         candle_metal_kernels::call_im2col1d_strided(
             &self.device.device,
             &command_buffer,
@@ -668,8 +693,7 @@ impl BackendStorage for MetalStorage {
             layout.shape().dims(),
             strides,
             (k_size, stride, padding, dilation),
-            &self.buffer,
-            layout.start_offset() * self.dtype.size_in_bytes(),
+            src,
             &dst,
         )
         .map_err(MetalError::from)?;
@@ -787,6 +811,7 @@ impl BackendStorage for MetalStorage {
             DType::U32 => "im2col_u32",
             dtype => crate::bail!("Metal conv2d {dtype:?} not implemented"),
         };
+        let src = buffer_o(&self.buffer, layout, self.dtype);
         candle_metal_kernels::call_im2col_strided(
             &self.device.device,
             &command_buffer,
@@ -795,8 +820,7 @@ impl BackendStorage for MetalStorage {
             layout.shape().dims(),
             layout.stride(),
             (h_k, w_k, stride, padding, dilation),
-            &self.buffer,
-            layout.start_offset() * self.dtype.size_in_bytes(),
+            src,
             &dst,
         )
         .map_err(MetalError::from)?;
@@ -1009,6 +1033,7 @@ impl BackendStorage for MetalStorage {
             .device
             .new_buffer(dst_el, self.dtype, "upsample_nearest2d")?;
         let command_buffer = self.device.command_buffer()?;
+        let src = buffer_o(&self.buffer, inp_l, self.dtype);
         candle_metal_kernels::call_upsample_nearest_2d(
             &self.device.device,
             &command_buffer,
@@ -1018,8 +1043,7 @@ impl BackendStorage for MetalStorage {
             strides,
             out_w,
             out_h,
-            &self.buffer,
-            inp_l.start_offset() * self.dtype.size_in_bytes(),
+            src,
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -1027,9 +1051,8 @@ impl BackendStorage for MetalStorage {
     }
 
     fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
-        let (ids_o1, _) = match ids_l.contiguous_offsets() {
-            Some(o12) => o12,
-            None => Err(crate::Error::RequiresContiguous { op: "gather" }.bt())?,
-        };
+        if !ids_l.is_contiguous() {
+            return Err(crate::Error::RequiresContiguous { op: "gather" }.bt());
+        };
         let ids_el = ids_l.dims()[dim];
         let dst_el = ids_l.shape().elem_count();
@@ -1039,9 +1062,12 @@ impl BackendStorage for MetalStorage {
         let name = match (ids.dtype, self.dtype) {
             (DType::U32, DType::F32) => "gather_u32_f32",
             (DType::U32, DType::F16) => "gather_u32_f16",
+            (DType::U32, DType::BF16) => "gather_u32_bf16",
             (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
         };
         let command_buffer = self.device.command_buffer()?;
+        let src = buffer_o(&self.buffer, src_l, dtype);
+        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
         candle_metal_kernels::call_gather(
             &device.device,
             &command_buffer,
@@ -1050,10 +1076,8 @@ impl BackendStorage for MetalStorage {
             src_l.dims(),
             ids_el,
             dim,
-            &self.buffer,
-            src_l.start_offset() * dtype.size_in_bytes(),
-            &ids.buffer,
-            ids_o1 * ids.dtype.size_in_bytes(),
+            src,
+            ids,
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -1071,13 +1095,8 @@ impl BackendStorage for MetalStorage {
     ) -> Result<Self> {
         let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
         self.copy_strided_src(&mut acc, 0, l)?;
-        let (ids_offset, _) = match ids_l.contiguous_offsets() {
-            Some(o12) => o12,
-            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
-        };
-        let src_offset = match src_l.contiguous_offsets() {
-            Some((o1, _)) => o1,
-            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
-        };
+        if !ids_l.is_contiguous() || !src_l.is_contiguous() {
+            return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
+        };
         let name = match (ids.dtype, self.dtype) {
             (DType::U8, DType::F32) => "sa_u8_f32",
@@ -1096,6 +1115,8 @@ impl BackendStorage for MetalStorage {
             })?,
         };
         let command_buffer = self.device.command_buffer()?;
+        let src = buffer_o(&src.buffer, src_l, src.dtype);
+        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
         candle_metal_kernels::call_scatter_add(
             &self.device.device,
             &command_buffer,
@@ -1104,10 +1125,8 @@ impl BackendStorage for MetalStorage {
             src_l.dims(),
             l.dims(),
             dim,
-            &src.buffer,
-            src_offset * src.dtype.size_in_bytes(),
-            &ids.buffer,
-            ids_offset * ids.dtype.size_in_bytes(),
+            src,
+            ids,
             &acc.buffer,
         )
         .map_err(MetalError::from)?;
@@ -1143,6 +1162,8 @@ impl BackendStorage for MetalStorage {
             }
         };
         let command_buffer = self.device.command_buffer()?;
+        let src = buffer_o(&self.buffer, src_l, dtype);
+        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -1154,10 +1175,8 @@ impl BackendStorage for MetalStorage {
             src_l.is_contiguous(),
             src_l.dims(),
             src_l.stride(),
-            &self.buffer,
-            src_l.start_offset() * dtype.size_in_bytes(),
-            &ids.buffer,
-            ids_l.start_offset() * ids.dtype.size_in_bytes(),
+            src,
+            ids,
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -1175,13 +1194,8 @@ impl BackendStorage for MetalStorage {
     ) -> Result<Self> {
         let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
         self.copy_strided_src(&mut acc, 0, l)?;
-        let (ids_offset, _) = match ids_l.contiguous_offsets() {
-            Some(o12) => o12,
-            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
-        };
-        let src_offset = match src_l.contiguous_offsets() {
-            Some((o1, _)) => o1,
-            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
-        };
+        if !ids_l.is_contiguous() || !src_l.is_contiguous() {
+            return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt());
+        };
         let name = match (ids.dtype, self.dtype) {
             (DType::I64, DType::BF16) => "ia_i64_bf16",
@@ -1212,6 +1226,8 @@ impl BackendStorage for MetalStorage {
             })?,
         };
         let command_buffer = self.device.command_buffer()?;
+        let src = buffer_o(&src.buffer, src_l, src.dtype);
+        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
         candle_metal_kernels::call_index_add(
             &self.device.device,
             &command_buffer,
@@ -1221,10 +1237,8 @@ impl BackendStorage for MetalStorage {
             l.dims(),
             ids_l.dims(),
             dim,
-            &src.buffer,
-            src_offset * src.dtype.size_in_bytes(),
-            &ids.buffer,
-            ids_offset * ids.dtype.size_in_bytes(),
+            src,
+            ids,
             &acc.buffer,
         )
         .map_err(MetalError::from)?;
@@ -1358,17 +1372,20 @@ impl BackendStorage for MetalStorage {
             DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
             dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
         };
+        let src = buffer_o(&self.buffer, src_l, self.dtype);
+        let dst = BufferOffset {
+            buffer: &dst.buffer,
+            offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
+        };
         candle_metal_kernels::call_unary_strided(
             &self.device.device,
             &command_buffer,
             &self.device.kernels,
             kernel_name,
             src_l.dims(),
-            &self.buffer,
+            src,
             src_l.stride(),
-            src_l.start_offset() * self.dtype.size_in_bytes(),
-            &dst.buffer,
-            dst_offset * dst.dtype.size_in_bytes(),
+            dst,
         )
         .map_err(MetalError::from)?;
         command_buffer.set_label("copy_strided");
@@ -1402,10 +1419,9 @@ impl MetalStorage {
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
         let command_buffer = device.command_buffer()?;
-        let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
-            && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
-            && &op[..1] != "b"
-        {
+        let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
+        let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
+        let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
             use candle_metal_kernels::binary::contiguous;
 
             let (kernel_name, dtype) = match (op, self.dtype) {
@@ -1486,8 +1502,8 @@ impl MetalStorage {
             &device.kernels,
             kernel_name,
             el_count,
-            &self.buffer,
-            &rhs.buffer,
+            lhs,
+            rhs,
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -1585,12 +1601,10 @@ impl MetalStorage {
             &device.kernels,
             kernel_name,
             lhs_l.dims(),
-            &self.buffer,
+            lhs,
             lhs_l.stride(),
-            lhs_l.start_offset() * self.dtype.size_in_bytes(),
-            &rhs.buffer,
+            rhs,
             rhs_l.stride(),
-            rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
             &buffer,
         )
         .map_err(MetalError::from)?;
@@ -1796,6 +1810,10 @@ impl BackendDevice for MetalDevice {
 
         Ok(())
     }
+
+    fn synchronize(&self) -> Result<()> {
+        self.wait_until_completed()
+    }
 }
 
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
@@ -40,6 +40,7 @@ fn quantize_q8_1(
     src: &CudaView<f32>,
     dst: &mut CudaSlice<u8>,
     elem_count: usize,
+    ky: usize,
     dev: &CudaDevice,
 ) -> Result<()> {
     use cudarc::driver::LaunchAsync;
@@ -49,7 +50,7 @@ fn quantize_q8_1(
     let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
     let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
     let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (num_blocks as u32, 1, 1),
+        grid_dim: (num_blocks as u32, ky as u32, 1),
         block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
         shared_mem_bytes: 0,
     };
@@ -165,6 +166,7 @@ fn mul_mat_vec_via_q8_1(
     dtype: GgmlDType,
     ncols: usize,
     nrows: usize,
+    b_size: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
     use cudarc::driver::LaunchAsync;
@@ -173,14 +175,18 @@ fn mul_mat_vec_via_q8_1(
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
     }
-    if y.len() != ncols {
+    if y.len() != ncols * b_size {
         crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
     }
+    if b_size == 0 || b_size > 8 {
+        crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
+    }
     // Start by quantizing y
     let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
-    let y_size_in_bytes = ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    let y_size_in_bytes =
+        b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
     let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
-    quantize_q8_1(y, &mut y_q8_1, ncols, dev)?;
+    quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
 
     let kernel_name = match dtype {
         GgmlDType::Q4_0 => "mul_mat_vec_q4_0_q8_1_cuda",
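A worked example of the buffer sizing above (hedged: the 36-byte/32-element block layout follows GGML's `block_q8_1` of 32 i8 quants plus two f16 scales, and the 512 row padding is an assumed value for `MATRIX_ROW_PADDING`):

```rust
fn pad(x: usize, to: usize) -> usize {
    (x + to - 1) / to * to
}

fn main() {
    let (type_size, block_size) = (36, 32); // assumed GGML Q8_1 block layout
    let (ncols, b_size, row_padding) = (256, 4, 512); // row_padding is an assumption
    let ncols_padded = pad(ncols, row_padding);
    assert_eq!(ncols_padded, 512);
    // One padded, quantized row per batch element:
    let y_size_in_bytes = b_size * ncols_padded * type_size / block_size;
    assert_eq!(y_size_in_bytes, 2304);
}
```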
@@ -195,11 +201,19 @@ fn mul_mat_vec_via_q8_1(
         GgmlDType::Q6K => "mul_mat_vec_q6_K_q8_1_cuda",
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+    let kernel_name = format!("{kernel_name}{b_size}");
+    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
+    // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
+    let (nblocks, nwarps) = match b_size {
+        1 => (nrows as u32, 4),
+        2..=4 => ((nrows as u32 + 1) / 2, 4),
+        5..=8 => ((nrows as u32 + 1) / 2, 2),
+        _ => crate::bail!("unexpected bsize {b_size}"),
+    };
     let cfg = cudarc::driver::LaunchConfig {
-        grid_dim: (nrows as u32, 1, 1),
-        block_dim: (WARP_SIZE as u32, 4, 1),
+        grid_dim: (nblocks, 1, 1),
+        block_dim: (WARP_SIZE as u32, nwarps, 1),
         shared_mem_bytes: 0,
     };
 
@@ -209,13 +223,83 @@ fn mul_mat_vec_via_q8_1(
         &dst,
         /* ncols_x */ ncols as i32,
         /* nrows_x */ nrows as i32,
-        /* nrows_y */ ncols as i32,
+        /* nrows_y */ ncols_padded as i32,
         /* nrows_dst */ nrows as i32,
     );
     unsafe { func.launch(cfg, params) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }
 
+#[allow(clippy::too_many_arguments)]
+fn mul_mat_via_q8_1(
+    data: &CudaSlice<u8>,
+    y: &CudaView<f32>,
+    dtype: GgmlDType,
+    x_rows: usize,
+    x_cols: usize,
+    y_rows: usize,
+    y_cols: usize,
+    dev: &CudaDevice,
+) -> Result<CudaStorage> {
+    use cudarc::driver::LaunchAsync;
+
+    let data_elems = data.len() / dtype.type_size() * dtype.block_size();
+    if data_elems < x_rows * x_cols {
+        crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
+    }
+    if y.len() != y_rows * y_cols {
+        crate::bail!("unexpected y size {}, {y_rows} {y_cols}", y.len())
+    }
+    if x_cols != y_rows {
+        crate::bail!("unexpected x/y size {x_rows} {x_cols} {y_rows} {y_cols}")
+    }
+    let k = x_cols;
+    // Start by quantizing y
+    let k_padded = pad(k, MATRIX_ROW_PADDING);
+    let y_size_in_bytes =
+        k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+    quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
+
+    let (kernel_name, mmq_x, mmq_y) = match dtype {
+        GgmlDType::Q4_0 => ("mul_mat_q4_0", 64, 128),
+        GgmlDType::Q4_1 => ("mul_mat_q4_1", 64, 128),
+        GgmlDType::Q5_0 => ("mul_mat_q5_0", 128, 64),
+        GgmlDType::Q5_1 => ("mul_mat_q5_1", 128, 64),
+        GgmlDType::Q8_0 => ("mul_mat_q8_0", 128, 64),
+        GgmlDType::Q2K => ("mul_mat_q2_K", 64, 128),
+        GgmlDType::Q3K => ("mul_mat_q3_K", 128, 128),
+        GgmlDType::Q4K => ("mul_mat_q4_K", 64, 128),
+        GgmlDType::Q5K => ("mul_mat_q5_K", 64, 128),
+        GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
+        _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
+    };
+    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
+    let cfg = cudarc::driver::LaunchConfig {
+        grid_dim: (
+            ceil_div(x_rows, mmq_y) as u32,
+            ceil_div(y_cols, mmq_x) as u32,
+            1,
+        ),
+        block_dim: (WARP_SIZE as u32, 4, 1),
+        shared_mem_bytes: 0,
+    };
+
+    let params = (
+        /* vx */ data,
+        /* vy */ &y_q8_1,
+        /* dst */ &dst,
+        /* ncols_x */ x_cols as i32,
+        /* nrows_x */ x_rows as i32,
+        /* ncols_y */ y_cols as i32,
+        /* nrows_y */ k_padded as i32,
+        /* nrows_dst */ x_rows as i32,
+    );
+    unsafe { func.launch(cfg, params) }.w()?;
+    Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
+}
+
 impl QCudaStorage {
     pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
         let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
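A worked example of the launch-grid arithmetic in `mul_mat_via_q8_1` above (using the Q4_0 tile sizes mmq_x = 64, mmq_y = 128 from the dispatch table; one thread block is launched per output tile):

```rust
fn ceil_div(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

fn main() {
    let (mmq_x, mmq_y) = (64, 128); // Q4_0 tile from the dispatch table
    let (x_rows, y_cols) = (4096, 512);
    let grid_dim = (ceil_div(x_rows, mmq_y), ceil_div(y_cols, mmq_x), 1);
    assert_eq!(grid_dim, (32, 8, 1)); // 256 blocks, one per 128x64 output tile
}
```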
@@ -313,7 +397,17 @@ impl QCudaStorage {
         storage: &CudaStorage,
         layout: &crate::Layout,
     ) -> Result<(CudaStorage, crate::Shape)> {
-        if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
+        let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            1
+        } else {
+            8
+        };
+        let use_vec_kernel = match layout.shape().dims() {
+            [b, m, _k] => b * m <= max_bm,
+            [b, _k] => *b <= max_bm,
+            _ => false,
+        };
+        if use_vec_kernel {
             self.dequantize_matmul_vec(self_shape, storage, layout)
         } else {
             self.dequantize_matmul(self_shape, storage, layout)
@@ -334,25 +428,31 @@ impl QCudaStorage {
             Some((o1, o2)) => rhs.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
         };
-        let (with_batch, k) = match rhs_l.shape().dims() {
-            [1, 1, k] => (true, k),
-            [1, k] => (false, k),
+        let (b_size, k) = match rhs_l.shape().dims() {
+            [b, m, k] => (b * m, *k),
+            [b, k] => (*b, *k),
             _ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
         };
-        if ncols != *k {
+        if ncols != k {
             crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
         }
 
         let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
             dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
         } else {
-            mul_mat_vec_via_q8_1(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?
-        };
-        let out_shape = if with_batch {
-            vec![1, 1, nrows]
-        } else {
-            vec![1, nrows]
+            mul_mat_vec_via_q8_1(
+                &self.data,
+                &rhs,
+                self.dtype,
+                ncols,
+                nrows,
+                b_size,
+                self.device(),
+            )?
         };
+        let mut out_shape = rhs_l.shape().dims().to_vec();
+        out_shape.pop();
+        out_shape.push(nrows);
         Ok((out, out_shape.into()))
     }
 
@@ -373,9 +473,30 @@ impl QCudaStorage {
             crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
         }
 
-        let data_f32 = self.dequantize(n * k)?;
-        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
-        let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
+        let out = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
+            let data_f32 = self.dequantize(n * k)?;
+            let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
+            storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?
+        } else {
+            let storage = storage.as_cuda_slice::<f32>()?;
+            let storage = match layout.contiguous_offsets() {
+                Some((o1, o2)) => storage.slice(o1..o2),
+                None => Err(crate::Error::RequiresContiguous {
+                    op: "quantized-matmul",
+                }
+                .bt())?,
+            };
+            mul_mat_via_q8_1(
+                &self.data,
+                &storage,
+                self.dtype,
+                /* x_rows */ n,
+                /* x_cols */ k,
+                /* y_rows */ k,
+                /* y_cols */ b * m,
+                self.device(),
+            )?
+        };
         let mut out_shape = layout.shape().dims().to_vec();
         out_shape.pop();
         out_shape.push(n);
@@ -412,7 +533,7 @@ mod test {
         let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
         let y = dev.htod_sync_copy(&vs).w()?;
-        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, &dev)?;
+        quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
 
@@ -430,6 +551,7 @@ mod test {
             /* dtype */ GgmlDType::Q4_0,
             /* ncols */ ncols,
             /* nrows */ 1,
+            /* b_size */ 1,
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
@@ -453,4 +575,44 @@ mod test {
         assert_eq!(vs[0], 5561851.0);
         Ok(())
     }
+
+    #[test]
+    fn cuda_mm_q8_1() -> Result<()> {
+        let dev = CudaDevice::new(0)?;
+        let ncols = 256;
+        let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
+        let y = dev.htod_sync_copy(&vs).w()?;
+        let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
+        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+        let cuda_storage = mul_mat_via_q8_1(
+            &xs.data,
+            &y.slice(..),
+            /* dtype */ GgmlDType::Q4_0,
+            /* x_rows */ 4,
+            /* x_cols */ ncols,
+            /* y_rows */ ncols,
+            /* y_cols */ 4,
+            &dev,
+        )?;
+        let vs = cuda_storage.as_cuda_slice::<f32>()?;
+        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+
+        /*
+        x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
+        x @ x.t() / 16
+        tensor([[  347480.0000,   869720.0000,  1391960.0000,  1914200.0000],
+                [  869720.0000,  2440536.0000,  4011352.0000,  5582166.5000],
+                [ 1391960.0000,  4011352.0000,  6630742.0000,  9250132.0000],
+                [ 1914200.0000,  5582166.5000,  9250132.0000, 12918099.0000]])
+        */
+        assert_eq!(vs.len(), 16);
+        assert_eq!(vs[0], 347604.0);
+        assert_eq!(vs[1], 888153.06);
+        assert_eq!(vs[4], 869780.7);
+        assert_eq!(vs[5], 2483145.0);
+        assert_eq!(vs[11], 9407368.0);
+        assert_eq!(vs[14], 9470856.0);
+        assert_eq!(vs[15], 13138824.0);
+        Ok(())
+    }
 }
@@ -149,8 +149,11 @@ impl QMetalStorage {
         let (n, k) = self_shape.dims2()?;
         let mut dst_shape = src_shape.dims().to_vec();
 
+        // We always use a single batch dimension and stack all the tensors in the batch on the
+        // second dimension as the implementation in candle-metal-kernels doesn't handle batch
+        // properly.
         let (b, m) = match dst_shape.len() {
-            3 => (dst_shape[0], dst_shape[1]),
+            3 => (1, dst_shape[0] * dst_shape[1]),
             2 => (1, dst_shape[0]),
             n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
         };
@@ -79,6 +79,9 @@ macro_rules! unary_op {
     ($fn_name:ident, $op_name:ident) => {
         pub fn $fn_name(&self) -> Result<Self> {
             let shape = self.shape();
+            if shape.elem_count() == 0 {
+                return Ok(self.clone());
+            }
             let storage = self
                 .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
@@ -92,6 +95,9 @@ macro_rules! binary_op {
    ($fn_name:ident, $op_name:ident) => {
        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
+            if shape.elem_count() == 0 {
+                return Ok(self.clone());
+            }
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
@@ -114,6 +120,9 @@ macro_rules! binary_op_scalar {
                    .broadcast_as(self.shape())?,
            };
            let shape = self.same_shape_binary_op(&rhs, stringify!($fn_name))?;
+            if self.elem_count() == 0 {
+                return Ok(self.clone());
+            }
            let storage = self.storage().binary_impl::<crate::op::$op_name>(
                &*rhs.storage(),
                self.layout(),
@@ -646,6 +655,9 @@ impl Tensor {
     /// # Ok::<(), candle_core::Error>(())
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
+        if self.elem_count() == 0 {
+            return Ok(self.clone());
+        }
         let storage = self.storage().affine(self.layout(), mul, add)?;
         let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
         Ok(from_storage(storage, self.shape(), op, false))
@@ -653,6 +665,9 @@ impl Tensor {
 
     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
+        if self.elem_count() == 0 {
+            return Ok(self.clone());
+        }
         let storage = self.storage().elu(self.layout(), alpha)?;
         let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
         Ok(from_storage(storage, self.shape(), op, false))
@@ -660,6 +675,9 @@ impl Tensor {
 
     /// Raise the tensor to some float exponent `e`.
     pub fn powf(&self, e: f64) -> Result<Self> {
+        if self.elem_count() == 0 {
+            return Ok(self.clone());
+        }
         let storage = self.storage().powf(self.layout(), e)?;
         let op = BackpropOp::new1(self, |t| Op::Powf(t, e));
         Ok(from_storage(storage, self.shape(), op, false))
@@ -1154,6 +1172,9 @@ impl Tensor {
         let n = b_dims[dim - 1];
 
         let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
+        if c_shape.elem_count() == 0 || k == 0 {
+            return Tensor::zeros(c_shape, self.dtype(), self.device());
+        }
         let batching: usize = a_dims[..dim - 2].iter().product();
         let batching_b: usize = b_dims[..dim - 2].iter().product();
         if k != k2 || batching != batching_b {
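Taken together, these early returns let zero-sized dimensions flow through unary ops, binary ops, and matmul without ever reaching a backend. A minimal CPU illustration (it mirrors the zero_dim test added later in this set):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::zeros((4, 0, 1), DType::F32, &dev)?;
    let u = t.sqrt()?; // unary op on an empty tensor is a no-op clone
    assert_eq!(u.dims(), &[4, 0, 1]);
    let t2 = Tensor::zeros((4, 3, 1), DType::F32, &dev)?;
    let mm = t2.matmul(&t.t()?)?; // (4, 3, 1) x (4, 1, 0) -> (4, 3, 0), all zeros
    assert_eq!(mm.dims(), &[4, 3, 0]);
    Ok(())
}
```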
@@ -3,7 +3,7 @@ use candle_core::{
     quantized::{self, GgmlDType},
     test_device,
     test_utils::to_vec2_round,
-    Device, Module, Result, Tensor,
+    Device, IndexOp, Module, Result, Tensor,
 };
 use quantized::{k_quants, GgmlType};
 use rand::prelude::*;
@@ -47,18 +47,14 @@ fn test_matmul(
 }
 
 fn quantized_matmul(device: &Device) -> Result<()> {
-    // TODO Enable this later when we enable cuda.
-    if device.is_cuda() {
-        return Ok(());
-    }
     let (m, k, n) = (3, 64, 4);
-    let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
+    let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
     let mut dst = vec![42.; 3 * 4];
     let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
     k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
-    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+    k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
     assert_eq!(
         dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
         &[
@@ -67,7 +63,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
         ]
     );
     let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
-    let mm = tensor_lhs.matmul(&tensor_rhs)?;
+    let mm = lhs.matmul(&tensor_rhs)?;
     assert_eq!(
         mm.to_vec2::<f32>()?,
         &[
@@ -79,7 +75,7 @@ fn quantized_matmul(device: &Device) -> Result<()> {
 
     let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
-    let res = matmul.forward(&tensor_lhs)?;
+    let res = matmul.forward(&lhs)?;
     match device {
         Device::Metal(_) => assert_eq!(
             to_vec2_round(&res, 0)?,
@@ -89,7 +85,15 @@ fn quantized_matmul(device: &Device) -> Result<()> {
                 [341970.0, 994574.0, 1656181.0, 2302182.0]
             ]
         ),
-        _ => assert_eq!(
+        Device::Cuda(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [84866.0, 214045.0, 344676.0, 473707.0],
+                [213425.0, 604313.0, 1000431.0, 1387960.0],
+                [342030.0, 994630.0, 1656248.0, 2302250.0]
+            ]
+        ),
+        Device::Cpu => assert_eq!(
            to_vec2_round(&res, 0)?,
            &[
                [85120.0, 214562.0, 345455.0, 474748.0],
@@ -98,22 +102,16 @@ fn quantized_matmul(device: &Device) -> Result<()> {
             ]
         ),
     }
 
+    test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
+
     Ok(())
 }
 
 fn quantized_matmul_neg(device: &Device) -> Result<()> {
-    // TODO Enable this later when we enable cuda.
-    if device.is_cuda() {
-        return Ok(());
-    }
     let (m, k, n) = (3, 64, 4);
-    let lhs = (0..(m * k))
+    let lhs_s = (0..(m * k))
         .map(|v| v as f32 - (m * k) as f32 / 2.0)
         .collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let lhs = Tensor::from_slice(&lhs_s, (m, k), device)?;
     let mut dst = vec![42.; 3 * 4];
     let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..k * n)
@@ -121,7 +119,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
         .collect::<Vec<_>>();
     let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
     k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
-    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
+    k_quants::matmul((m, k, n), &lhs_s, &rhs_t, &mut dst)?;
     assert_eq!(
         dst.iter().map(|x| x.round()).collect::<Vec<_>>(),
         &[
@@ -129,7 +127,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
             -196472.0, 63012.0, 324585.0, 587902.0
         ]
     );
-    let mm = tensor_lhs.matmul(&tensor_rhs)?;
+    let mm = lhs.matmul(&tensor_rhs)?;
     assert_eq!(
         to_vec2_round(&mm, 0)?,
         &[
@@ -141,7 +139,7 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
 
     let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
-    let res = matmul.forward(&tensor_lhs)?;
+    let res = matmul.forward(&lhs)?;
     match device {
         Device::Metal(_) => assert_eq!(
             to_vec2_round(&res, 0)?,
@@ -151,7 +149,15 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
                 [-196102.0, 63022.0, 324233.0, 587191.0]
             ]
         ),
-        _ => assert_eq!(
+        Device::Cuda(_) => assert_eq!(
+            to_vec2_round(&res, 0)?,
+            &[
+                [243740.0, -19762.0, -285476.0, -550498.0],
+                [23774.0, 21645.0, 19395.0, 18364.0],
+                [-196045.0, 63030.0, 324120.0, 587079.0]
+            ]
+        ),
+        Device::Cpu => assert_eq!(
            to_vec2_round(&res, 0)?,
            &[
                [243524.0, -19596.0, -285051.0, -549815.0],
@@ -160,22 +166,58 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
             ]
         ),
     }
+
+    let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
+    let res2 = matmul.forward(&lhs2)?;
+    let res2 = res2.i(1)?;
+    let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if device.is_cuda() {
+        assert!(diff < 0.1);
+    } else {
+        assert_eq!(diff, 0.);
+    }
     Ok(())
 }
 
-test_device!(
-    quantized_matmul,
-    quantized_matmul_cpu,
-    quantized_matmul_cuda,
-    quantized_matmul_metal
-);
-test_device!(
-    quantized_matmul_neg,
-    quantized_matmul_neg_cpu,
-    quantized_matmul_neg_cuda,
-    quantized_matmul_neg_metal
-);
+fn qmm_batch(dev: &Device) -> Result<()> {
+    let (lhs, rhs, _mm) = get_random_tensors(2, 256, 6, dev)?;
+    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
+    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
+    let mm = rhs.forward(&lhs)?;
+    assert_eq!(mm.shape().dims(), [2, 6]);
+    let lhs2 = Tensor::cat(&[&lhs, &lhs], 0)?;
+    let mm2 = rhs.forward(&lhs2)?;
+    assert_eq!(mm2.shape().dims(), [4, 6]);
+    let diff2 = (mm2.i(2..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff2, 0.0);
+    let lhs3 = Tensor::cat(&[&lhs2, &lhs], 0)?;
+    let mm3 = rhs.forward(&lhs3)?;
+    assert_eq!(mm3.shape().dims(), [6, 6]);
+    let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff3, 0.0);
+    let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert_eq!(diff3, 0.0);
+    let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
+    let mm4 = rhs.forward(&lhs4)?;
+    assert_eq!(mm4.shape().dims(), [12, 6]);
+    let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    if dev.is_cuda() {
+        // We use a different kernel for sizes from 1 to 8 on cuda which explains
+        // the difference here.
+        assert!(0. < diff4 && diff4 < 1e-4)
+    } else {
+        assert_eq!(diff4, 0.0)
+    };
+    let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
+        .abs()?
+        .sum_all()?
+        .to_vec0::<f32>()?;
+    assert_eq!(diff4, 0.0);
+    Ok(())
+}
+
+test_device!(quantized_matmul, qmm_cpu, qmm_cuda, qmm_metal);
+test_device!(quantized_matmul_neg, qmm_n_cpu, qmm_n_cuda, qmm_n_metal);
+test_device!(qmm_batch, qmm_b_cpu, qmm_b_cuda, qmm_b_metal);
 
 fn quantize_q4_0(device: &Device) -> Result<()> {
     let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
@@ -1083,6 +1083,27 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn zero_dim(device: &Device) -> Result<()> {
+    let t = Tensor::zeros((4, 0, 1), DType::F32, device)?;
+    assert_eq!(t.dims3()?, (4, 0, 1));
+    let t2 = Tensor::zeros((4, 3, 1), DType::F32, device)?;
+    let t_cat = Tensor::cat(&[&t, &t2], 1)?;
+    assert_eq!(t_cat.dims3()?, (4, 3, 1));
+    let t_cat = Tensor::cat(&[&t, &t], 1)?;
+    assert_eq!(t_cat.dims3()?, (4, 0, 1));
+    let t_unary = t.sqrt()?;
+    assert_eq!(t_unary.dims3()?, (4, 0, 1));
+    let t_plus = (&t + 1.)?;
+    assert_eq!(t_plus.dims3()?, (4, 0, 1));
+    let t_mm = t2.matmul(&t.t()?)?;
+    assert_eq!(t_mm.dims3()?, (4, 3, 0));
+    let t_mm = t.matmul(&t2.t()?)?;
+    assert_eq!(t_mm.dims3()?, (4, 0, 3));
+    let t_mm = t.t()?.matmul(&t)?;
+    assert_eq!(t_mm.dims3()?, (4, 1, 1));
+    Ok(())
+}
+
 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
 test_device!(full, full_cpu, full_gpu, full_metal);
@@ -1131,6 +1152,7 @@ test_device!(
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
+test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381
@@ -16,6 +16,30 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "2b")]
+    Base2B,
+    #[value(name = "7b")]
+    Base7B,
+    #[value(name = "2b-it")]
+    Instruct2B,
+    #[value(name = "7b-it")]
+    Instruct7B,
+    #[value(name = "1.1-2b-it")]
+    InstructV1_1_2B,
+    #[value(name = "1.1-7b-it")]
+    InstructV1_1_7B,
+    #[value(name = "code-2b")]
+    CodeBase2B,
+    #[value(name = "code-7b")]
+    CodeBase7B,
+    #[value(name = "code-2b-it")]
+    CodeInstruct2B,
+    #[value(name = "code-7b-it")]
+    CodeInstruct7B,
+}
+
 struct TextGeneration {
     model: Model,
     device: Device,
@@ -165,6 +189,10 @@ struct Args {
     /// The context size to consider for the repeat penalty.
     #[arg(long, default_value_t = 64)]
     repeat_last_n: usize,
+
+    /// The model to use.
+    #[arg(long, default_value = "2b")]
+    which: Which,
 }
 
 fn main() -> Result<()> {
@@ -196,14 +224,19 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let model_id = match &args.model_id {
-        Some(model_id) => match model_id.as_str() {
-            "7b-it" => "google/gemma-7b-it".to_string(),
-            "7b" => "google/gemma-7b".to_string(),
-            "2b-it" => "google/gemma-2b-it".to_string(),
-            "2b" => "google/gemma-2b".to_string(),
-            _ => model_id.to_string(),
-        },
-        None => "google/gemma-2b".to_string(),
+        Some(model_id) => model_id.to_string(),
+        None => match args.which {
+            Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(),
+            Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(),
+            Which::Base2B => "google/gemma-2b".to_string(),
+            Which::Base7B => "google/gemma-7b".to_string(),
+            Which::Instruct2B => "google/gemma-2b-it".to_string(),
+            Which::Instruct7B => "google/gemma-7b-it".to_string(),
+            Which::CodeBase2B => "google/codegemma-2b".to_string(),
+            Which::CodeBase7B => "google/codegemma-7b".to_string(),
+            Which::CodeInstruct2B => "google/codegemma-2b-it".to_string(),
+            Which::CodeInstruct7B => "google/codegemma-7b-it".to_string(),
+        },
     };
     let repo = api.repo(Repo::with_revision(
         model_id,
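With the new flag, a checkpoint is selected by name instead of by overloading --model-id; an invocation would presumably look like `cargo run --example gemma --release -- --which code-2b-it --prompt "..."` (shown for illustration, using one of the clap value names defined above).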
@@ -31,6 +31,8 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
enum Which {
    V1,
    V2,
    V3,
    V3Instruct,
    #[value(name = "solar-10.7b")]
    Solar10_7B,
    #[value(name = "tiny-llama-1.1b-chat")]
@@ -45,8 +47,8 @@ struct Args {
    cpu: bool,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
@@ -90,11 +92,11 @@ struct Args {
    use_flash_attn: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    #[arg(long, default_value_t = 128)]
    repeat_last_n: usize,
}

@@ -118,13 +120,18 @@ fn main() -> Result<()> {
        Some("bf16") => DType::BF16,
        Some("f32") => DType::F32,
        Some(dtype) => bail!("Unsupported dtype {dtype}"),
        None => DType::F16,
        None => match args.which {
            Which::V3 | Which::V3Instruct => DType::BF16,
            Which::V1 | Which::V2 | Which::Solar10_7B | Which::TinyLlama1_1BChat => DType::F16,
        },
    };
    let (llama, tokenizer_filename, mut cache) = {
    let (llama, tokenizer_filename, mut cache, config) = {
        let api = Api::new()?;
        let model_id = args.model_id.unwrap_or_else(|| match args.which {
            Which::V1 => "Narsil/amall-7b".to_string(),
            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
            Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
        });
@@ -138,7 +145,7 @@ fn main() -> Result<()> {
        let config = config.into_config(args.use_flash_attn);

        let filenames = match args.which {
            Which::V1 | Which::V2 | Which::Solar10_7B => {
            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
                candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
            }
            Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -146,10 +153,12 @@ fn main() -> Result<()> {
        let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        (Llama::load(vb, &config)?, tokenizer_filename, cache)
        (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
    };
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
    let eos_token_id = config
        .eos_token_id
        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
    let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
    let mut tokens = tokenizer
        .encode(prompt, true)
@@ -160,7 +169,7 @@ fn main() -> Result<()> {

    println!("starting the inference loop");
    print!("{prompt}");
    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), args.top_p);
    let start_gen = std::time::Instant::now();
    let mut index_pos = 0;
    let mut token_generated = 0;
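The hunks above wire Llama-v3 into the llama example: new `V3`/`V3Instruct` variants, a BF16 default dtype for them, and an EOS id that now comes from the model config before falling back to the tokenizer. A minimal sketch of that fallback, assuming `eos_token_id` is an `Option<u32>` on the config as the diff suggests:

```rust
// Hedged sketch of the eos-token fallback introduced above. `token_to_id`
// is the real `tokenizers` API; the "</s>" literal is an assumption carried
// over from the pre-Llama-3 default EOS_TOKEN.
fn resolve_eos_token(config_eos: Option<u32>, tokenizer: &tokenizers::Tokenizer) -> Option<u32> {
    // Prefer the id recorded in config.json (e.g. <|end_of_text|> for Llama-3),
    // otherwise look the legacy marker up in the tokenizer vocabulary.
    config_eos.or_else(|| tokenizer.token_to_id("</s>"))
}
```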
@@ -54,6 +54,7 @@ impl TextGeneration {
    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let dtype = self.model.dtype();
        let mut tokens = self
            .tokenizer
            .tokenizer()
@@ -66,7 +67,7 @@ impl TextGeneration {
            Some(token) => token,
            None => anyhow::bail!("cannot find the </s> token"),
        };
        let mut state = State::new(1, &self.config, &self.device)?;
        let mut state = State::new(1, &self.config, dtype, &self.device)?;
        let mut next_logits = None;
        for &t in tokens.iter() {
            let input = Tensor::new(&[t], &self.device)?;
@@ -84,7 +85,7 @@ impl TextGeneration {
            Some(logits) => logits,
            None => anyhow::bail!("cannot work on an empty prompt"),
        };
        let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
        let logits = logits.squeeze(0)?.to_dtype(dtype)?;
        let logits = if self.repeat_penalty == 1. {
            logits
        } else {
@@ -210,6 +211,9 @@ struct Args {
    #[arg(long)]
    config_file: Option<String>,

    #[arg(long, default_value = "f32")]
    dtype: String,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,
@@ -220,6 +224,7 @@ struct Args {
}

fn main() -> Result<()> {
    use std::str::FromStr;
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

@@ -279,7 +284,8 @@ fn main() -> Result<()> {
    let start = std::time::Instant::now();
    let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
    let dtype = DType::from_str(&args.dtype)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = Model::new(&config, vb.pp("backbone"))?;
    println!("loaded the model in {:?}", start.elapsed());
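This change threads a `--dtype` flag through model, state, and logits instead of hard-coding F32. The flag is parsed with `DType::from_str`, exactly as in the hunk above; a self-contained sketch (the accepted names such as "f32"/"f16"/"bf16" are an assumption based on candle's dtype set):

```rust
use std::str::FromStr;

// Parse a dtype flag the way the example above does; the parse error is
// forwarded through anyhow just as in the diff's main().
fn parse_dtype(flag: &str) -> anyhow::Result<candle::DType> {
    Ok(candle::DType::from_str(flag)?)
}
```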
@@ -123,7 +123,7 @@ impl TextGeneration {
        let next_token = self.logits_processor.sample(&logits)?;
        tokens.push(next_token);
        generated_tokens += 1;
        if next_token == eos_token {
        if next_token == eos_token || tokens.ends_with(&[27, 10619, 29] /* <END> */) {
            break;
        }
        let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
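This hunk stops generation either on the EOS id or when the tail of the token stream spells a literal `<END>` marker (the ids `[27, 10619, 29]` are the tokenization claimed by the diff's comment). The check is just a suffix comparison:

```rust
// Minimal sketch of the multi-token stop check used above; token ids are
// model-specific, so the [27, 10619, 29] triple is purely illustrative.
fn should_stop(tokens: &[u32], eos_token: u32, stop_seq: &[u32]) -> bool {
    tokens.last() == Some(&eos_token) || tokens.ends_with(stop_seq)
}
```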
@@ -67,6 +67,8 @@ enum Which {
    Mixtral,
    #[value(name = "mixtral-instruct")]
    MixtralInstruct,
    #[value(name = "llama3-8b")]
    L8b,
}

impl Which {
@@ -82,7 +84,8 @@ impl Which {
            | Self::L13bCode
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b => false,
            | Self::Leo13b
            | Self::L8b => false,
            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
            // same way. Starling is a fine tuned version of OpenChat.
            Self::OpenChat35
@@ -116,7 +119,8 @@ impl Which {
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::OpenChat35
            | Self::Starling7bAlpha => false,
            | Self::Starling7bAlpha
            | Self::L8b => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
    }
@@ -140,7 +144,8 @@ impl Which {
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta => false,
            | Self::Zephyr7bBeta
            | Self::L8b => false,
            Self::OpenChat35 | Self::Starling7bAlpha => true,
        }
    }
@@ -167,6 +172,7 @@ impl Which {
            | Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
            Which::OpenChat35 => "openchat/openchat_3.5",
            Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
            Self::L8b => "meta-llama/Meta-Llama-3-8B",
        }
    }
}
@@ -322,6 +328,11 @@ impl Args {
                    "TheBloke/Starling-LM-7B-alpha-GGUF",
                    "starling-lm-7b-alpha.Q4_K_M.gguf",
                ),
                // TODO: swap to TheBloke model when available
                Which::L8b => (
                    "QuantFactory/Meta-Llama-3-8B-GGUF",
                    "Meta-Llama-3-8B.Q4_K_S.gguf",
                ),
            };
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(repo.to_string());
@@ -420,7 +431,8 @@ fn main() -> anyhow::Result<()> {
        | Which::L13bCode
        | Which::L34bCode
        | Which::Leo7b
        | Which::Leo13b => 1,
        | Which::Leo13b
        | Which::L8b => 1,
        Which::Mixtral
        | Which::MixtralInstruct
        | Which::Mistral7b
@@ -537,11 +549,14 @@ fn main() -> anyhow::Result<()> {
        std::io::stdout().flush()?;
    }

    let eos_token = if args.which.is_open_chat() {
        "<|end_of_turn|>"
    } else {
        "</s>"
    let eos_token = match args.which {
        Which::L8b => "<|end_of_text|>",
        _ => match args.which.is_open_chat() {
            true => "<|end_of_turn|>",
            false => "</s>",
        },
    };

    let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
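The quantized example gains a `llama3-8b` variant that downloads a Q4_K_S GGUF from QuantFactory and switches the EOS marker to `<|end_of_text|>`. The (repo, filename) pair is resolved through hf-hub exactly as elsewhere in the file; a self-contained sketch of that lookup:

```rust
// Sketch of the hub lookup behind the Which::L8b arm above; the repo and
// file names are copied from the diff, the helper function is hypothetical.
fn fetch_llama3_gguf() -> anyhow::Result<std::path::PathBuf> {
    let api = hf_hub::api::sync::Api::new()?;
    let repo = api.model("QuantFactory/Meta-Llama-3-8B-GGUF".to_string());
    Ok(repo.get("Meta-Llama-3-8B.Q4_K_S.gguf")?)
}
```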
candle-examples/examples/recurrent-gemma/README.md — new file (9 lines)
@@ -0,0 +1,9 @@
# candle-recurrent-gemma

This example uses the 2B base version of the RecurrentGemma model; see the
[huggingface model card](https://huggingface.co/google/recurrentgemma-2b).

```bash
cargo run --features cuda -r --example recurrent-gemma -- \
  --prompt "Write me a poem about Machine Learning."
```
candle-examples/examples/recurrent-gemma/main.rs — new file (321 lines)
@@ -0,0 +1,321 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::quantized_recurrent_gemma::Model as QModel;
use candle_transformers::models::recurrent_gemma::{Config, Model as BModel};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

enum Model {
    B(BModel),
    Q(QModel),
}

impl Model {
    fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
        match self {
            Self::B(m) => m.forward(xs, pos),
            Self::Q(m) => m.forward(xs, pos),
        }
    }
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "2b")]
    Base2B,
    #[value(name = "2b-it")]
    Instruct2B,
}

struct TextGeneration {
    model: Model,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: Model,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        top_k: usize,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let sampling = match temp {
            None => candle_transformers::generation::Sampling::ArgMax,
            Some(temperature) => match top_p {
                None => candle_transformers::generation::Sampling::TopK {
                    temperature,
                    k: top_k,
                },
                Some(top_p) => candle_transformers::generation::Sampling::TopKThenTopP {
                    temperature,
                    k: top_k,
                    p: top_p,
                },
            },
        };
        let logits_processor = LogitsProcessor::from_sampling(seed, sampling);
        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<eos>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    #[arg(long, default_value_t = 250)]
    top_k: usize,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 8000)]
    sample_len: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long)]
    weight_files: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model to use.
    #[arg(long, default_value = "2b")]
    which: Which,

    #[arg(long)]
    quantized: bool,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match &args.model_id {
        Some(model_id) => model_id.to_string(),
        None => match args.which {
            Which::Base2B => "google/recurrentgemma-2b".to_string(),
            Which::Instruct2B => "google/recurrentgemma-2b-it".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
    };
    let config_filename = match args.config_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("config.json")?,
    };
    let filenames = match args.weight_files {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => {
            if args.quantized {
                let filename = match args.which {
                    Which::Base2B => "recurrent-gemma-2b-q4k.gguf",
                    Which::Instruct2B => "recurrent-gemma-7b-q4k.gguf",
                };
                let filename = api.model("lmz/candle-gemma".to_string()).get(filename)?;
                vec![filename]
            } else {
                candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
            }
        }
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
    let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let dtype = if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    };
    let model = if args.quantized {
        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
            &filenames[0],
            &device,
        )?;
        Model::Q(QModel::new(&config, vb.pp("model"))?)
    } else {
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        Model::B(BModel::new(&config, vb.pp("model"))?)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.top_k,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
@@ -46,7 +46,8 @@ The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.
- `--num-samples`: the number of samples to generate.
- `--num-samples`: the number of samples to generate iteratively.
- `--bsize`: the number of samples to generate simultaneously.
- `--final-image`: the filename for the generated image(s).

### Using flash-attention
@@ -9,6 +9,7 @@ use candle_transformers::models::stable_diffusion;
use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use stable_diffusion::vae::AutoEncoderKL;
use tokenizers::Tokenizer;

#[derive(Parser)]
@@ -64,9 +65,13 @@ struct Args {
    #[arg(long)]
    n_steps: Option<usize>,

    /// The number of samples to generate.
    /// The number of samples to generate iteratively.
    #[arg(long, default_value_t = 1)]
    num_samples: i64,
    num_samples: usize,

    /// The number of samples to generate simultaneously.
    #[arg[long, default_value_t = 1]]
    bsize: usize,

    /// The name of the final image to generate.
    #[arg(long, value_name = "FILE", default_value = "sd_final.png")]
@@ -236,8 +241,8 @@ impl ModelFile {

fn output_filename(
    basename: &str,
    sample_idx: i64,
    num_samples: i64,
    sample_idx: usize,
    num_samples: usize,
    timestep_idx: Option<usize>,
) -> String {
    let filename = if num_samples > 1 {
@@ -261,6 +266,33 @@ fn output_filename(
    }
}

#[allow(clippy::too_many_arguments)]
fn save_image(
    vae: &AutoEncoderKL,
    latents: &Tensor,
    vae_scale: f64,
    bsize: usize,
    idx: usize,
    final_image: &str,
    num_samples: usize,
    timestep_ids: Option<usize>,
) -> Result<()> {
    let images = vae.decode(&(latents / vae_scale)?)?;
    let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
    let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
    for batch in 0..bsize {
        let image = images.i(batch)?;
        let image_filename = output_filename(
            final_image,
            (bsize * idx) + batch + 1,
            batch + num_samples,
            timestep_ids,
        );
        candle_examples::save_image(&image, image_filename)?;
    }
    Ok(())
}

#[allow(clippy::too_many_arguments)]
fn text_embeddings(
    prompt: &str,
@@ -382,6 +414,7 @@ fn run(args: Args) -> Result<()> {
        final_image,
        sliced_attention_size,
        num_samples,
        bsize,
        sd_version,
        clip_weights,
        vae_weights,
@@ -475,6 +508,7 @@ fn run(args: Args) -> Result<()> {
        .collect::<Result<Vec<_>>>()?;

    let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
    let text_embeddings = text_embeddings.repeat((bsize, 1, 1))?;
    println!("{text_embeddings:?}");

    println!("Building the autoencoder.");
@@ -496,7 +530,6 @@ fn run(args: Args) -> Result<()> {
    } else {
        0
    };
    let bsize = 1;

    let vae_scale = match sd_version {
        StableDiffusionVersion::V1_5
@@ -560,12 +593,16 @@ fn run(args: Args) -> Result<()> {
            println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);

            if args.intermediary_images {
                let image = vae.decode(&(&latents / vae_scale)?)?;
                let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
                let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
                let image_filename =
                    output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1));
                candle_examples::save_image(&image, image_filename)?
                save_image(
                    &vae,
                    &latents,
                    vae_scale,
                    bsize,
                    idx,
                    &final_image,
                    num_samples,
                    Some(timestep_index + 1),
                )?;
            }
        }

@@ -574,11 +611,16 @@ fn run(args: Args) -> Result<()> {
            idx + 1,
            num_samples
        );
        let image = vae.decode(&(&latents / vae_scale)?)?;
        let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
        let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
        let image_filename = output_filename(&final_image, idx + 1, num_samples, None);
        candle_examples::save_image(&image, image_filename)?
        save_image(
            &vae,
            &latents,
            vae_scale,
            bsize,
            idx,
            &final_image,
            num_samples,
            None,
        )?;
    }
    Ok(())
}
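With the new `--bsize` flag, `save_image` decodes a whole batch of latents at once and derives a global sample number from the iteration index and the position inside the batch, so filenames stay unique across both loops. A worked example of that numbering, using the `(bsize * idx) + batch + 1` expression from the diff:

```rust
// Worked example of the sample numbering used by save_image above.
// With bsize = 2 and num_samples = 3 iterations this prints samples 1..=6.
fn global_sample_index(bsize: usize, idx: usize, batch: usize) -> usize {
    (bsize * idx) + batch + 1
}

fn main() {
    let (bsize, num_samples) = (2, 3);
    for idx in 0..num_samples {
        for batch in 0..bsize {
            println!(
                "iter {idx}, batch {batch} -> sample {}",
                global_sample_index(bsize, idx, batch)
            );
        }
    }
}
```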
@@ -115,7 +115,7 @@ pub fn main() -> anyhow::Result<()> {
    let processor = image_processor::ViTImageProcessor::new(&processor_config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;
    let image = processor.preprocess(image)?.to_device(&device)?;

    let encoder_xs = model.encoder().forward(&image)?;
candle-examples/examples/yolo-v8/assets/bike.pp.jpg — new binary file (175 KiB, not shown)
File diff suppressed because it is too large.
@@ -207,6 +207,9 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)

GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif

SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
@@ -1,11 +1,15 @@
use metal::{
    Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
    Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
    Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
    FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split};

const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
@@ -18,100 +22,6 @@ const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");

/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    let count = (size + width - 1) / width;
    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}

fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
    ($type:ty) => {
        impl EncoderParam for $type {
            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
                encoder.set_bytes(
                    position,
                    core::mem::size_of::<$type>() as u64,
                    &data as *const $type as *const c_void,
                );
            }
        }
    };
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);

impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}

impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}
impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}
impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            set_param($encoder, _index, $param);
            _index += 1;
        )*
    );
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
    Affine,
@@ -235,6 +145,12 @@ pub struct Kernels {
    pipelines: RwLock<Pipelines>,
}

impl Default for Kernels {
    fn default() -> Self {
        Self::new()
    }
}

impl Kernels {
    pub fn new() -> Self {
        let libraries = RwLock::new(Libraries::new());
@@ -358,17 +274,17 @@ pub fn call_unary_contiguous(
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: &Buffer,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, output));
    set_params!(encoder, (length, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -396,21 +312,24 @@ pub fn call_copy2d(
    set_params!(
        encoder,
        (
            d1,
            d2,
            src_s,
            dst_s,
            d1 as i64,
            d2 as i64,
            src_s as i64,
            dst_s as i64,
            (input, src_o_in_bytes),
            (output, dst_o_in_bytes)
        )
    );

    let width: usize = d1 * d2;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width / 4);

    let grid_dims = MTLSize {
        width: d1 as u64,
        height: d2 as u64,
        depth: 1,
    };
    let group_dims = get_block_dims(d1 as u64, d2 as u64, 1);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.dispatch_threads(grid_dims, group_dims);
    encoder.end_encoding();
    Ok(())
}
@@ -422,11 +341,9 @@ pub fn call_unary_strided(
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: &Buffer,
    input: BufferOffset,
    strides: &[usize],
    offset: usize,
    output: &Buffer,
    output_offset: usize,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;

@@ -435,23 +352,13 @@ pub fn call_unary_strided(
    encoder.set_compute_pipeline_state(&pipeline);

    let length: usize = shape.iter().product();
    set_params!(
        encoder,
        (
            length,
            num_dims,
            shape,
            strides,
            (input, offset),
            (output, output_offset)
        )
    );
    set_params!(encoder, (length, num_dims, shape, strides, &input, &output));

    let width: usize = shape.iter().product();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
@@ -464,8 +371,8 @@ pub fn call_binary_contiguous(
    kernels: &Kernels,
    kernel_name: binary::contiguous::Kernel,
    length: usize,
    left: &Buffer,
    right: &Buffer,
    left: BufferOffset,
    right: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
@@ -473,12 +380,12 @@ pub fn call_binary_contiguous(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, left, right, output));
    set_params!(encoder, (length, &left, &right, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(left, metal::MTLResourceUsage::Read);
    encoder.use_resource(right, metal::MTLResourceUsage::Read);
    encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -492,12 +399,10 @@ pub fn call_binary_strided(
    kernels: &Kernels,
    name: binary::strided::Kernel,
    shape: &[usize],
    left_input: &Buffer,
    left_input: BufferOffset,
    left_strides: &[usize],
    left_offset: usize,
    right_input: &Buffer,
    right_input: BufferOffset,
    right_strides: &[usize],
    right_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
@@ -517,16 +422,16 @@ pub fn call_binary_strided(
            shape,
            left_strides,
            right_strides,
            (left_input, left_offset),
            (right_input, right_offset),
            &left_input,
            &right_input,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);

    encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
    encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -540,8 +445,7 @@ pub fn call_cast_contiguous(
    kernels: &Kernels,
    kernel_name: &'static str,
    length: usize,
    input: &Buffer,
    input_offset: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -549,10 +453,10 @@ pub fn call_cast_contiguous(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, (input, input_offset), output));
    set_params!(encoder, (length, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -566,9 +470,8 @@ pub fn call_cast_strided(
    kernels: &Kernels,
    kernel_name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input: BufferOffset,
    input_strides: &[usize],
    input_offset: usize,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
@@ -580,25 +483,19 @@ pub fn call_cast_strided(

    set_params!(
        encoder,
        (
            length,
            shape.len(),
            shape,
            input_strides,
            (input, input_offset),
            output
        )
        (length, shape.len(), shape, input_strides, &input, output)
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
@@ -606,8 +503,7 @@ pub fn call_reduce_contiguous(
    kernel_name: &'static str,
    length: usize,
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -616,10 +512,7 @@ pub fn call_reduce_contiguous(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (length, elements_to_sum, (input, input_offset), output)
    );
    set_params!(encoder, (length, elements_to_sum, &input, output));

    let thread_group_count = MTLSize {
        width: out_length as u64,
@@ -639,13 +532,14 @@ pub fn call_reduce_contiguous(
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_reduce_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
@@ -654,8 +548,7 @@ pub fn call_reduce_strided(
    shape: &[usize],
    strides: &[usize],
    out_length: usize,
    input: &Buffer,
    input_offset: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let length: usize = shape.iter().product();
@@ -667,14 +560,7 @@ pub fn call_reduce_strided(

    set_params!(
        encoder,
        (
            shape.len(),
            shape,
            strides,
            elements_to_sum,
            (input, input_offset),
            output
        )
        (shape.len(), shape, strides, elements_to_sum, &input, output)
    );

    let thread_group_count = MTLSize {
@@ -695,7 +581,7 @@ pub fn call_reduce_strided(
        depth: 1,
    };

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -944,7 +830,7 @@ pub fn call_affine(
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
    add: f32,
@@ -954,10 +840,10 @@ pub fn call_affine(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, add, input, output));
    set_params!(encoder, (size, mul, add, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -971,9 +857,8 @@ pub fn call_affine_strided(
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input: BufferOffset,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
    add: f32,
@@ -993,13 +878,13 @@ pub fn call_affine_strided(
            input_stride,
            mul,
            add,
            (input, input_offset),
            &input,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1013,7 +898,7 @@ pub fn call_powf(
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1022,10 +907,10 @@ pub fn call_powf(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, input, output));
    set_params!(encoder, (size, mul, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1039,9 +924,8 @@ pub fn call_powf_strided(
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input: BufferOffset,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1053,19 +937,11 @@ pub fn call_powf_strided(

    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            (input, input_offset),
            output
        )
        (size, shape.len(), shape, input_stride, mul, &input, output)
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1079,7 +955,7 @@ pub fn call_elu(
    kernels: &Kernels,
    name: &'static str,
    size: usize,
    input: &Buffer,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1088,10 +964,10 @@ pub fn call_elu(
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (size, mul, input, output));
    set_params!(encoder, (size, mul, &input, output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1105,9 +981,8 @@ pub fn call_elu_strided(
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: &Buffer,
    input: BufferOffset,
    input_stride: &[usize],
    input_offset: usize,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
@@ -1119,37 +994,30 @@ pub fn call_elu_strided(

    set_params!(
        encoder,
        (
            size,
            shape.len(),
            shape,
            input_stride,
            mul,
            (input, input_offset),
            output
        )
        (size, shape.len(), shape, input_stride, mul, &input, output)
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_where_cond_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    cond: &Buffer,
    (cond_stride, cond_offset): (&[usize], usize),
    left: &Buffer,
    (left_stride, left_offset): (&[usize], usize),
    right: &Buffer,
    (right_stride, right_offset): (&[usize], usize),
    cond: BufferOffset,
    cond_stride: &[usize],
    left: BufferOffset,
    left_stride: &[usize],
    right: BufferOffset,
    right_stride: &[usize],
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
@@ -1169,18 +1037,18 @@ pub fn call_where_cond_strided(
            cond_stride,
            left_stride,
            right_stride,
            (cond, cond_offset),
            (left, left_offset),
            (right, right_offset),
            &cond,
            &left,
            &right,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);

    encoder.use_resource(cond, metal::MTLResourceUsage::Read);
    encoder.use_resource(left, metal::MTLResourceUsage::Read);
    encoder.use_resource(right, metal::MTLResourceUsage::Read);
    encoder.use_resource(cond.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1199,10 +1067,8 @@ pub fn call_index_select(
    contiguous: bool,
    src_dims: &[usize],
    src_strides: &[usize],
    input: &Buffer,
    src_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
@@ -1227,16 +1093,16 @@ pub fn call_index_select(
            contiguous,
            src_dims,
            src_strides,
            (input, src_offset),
            (ids, ids_offset),
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1252,10 +1118,8 @@ pub fn call_gather(
    shape: &[usize],
    ids_size: usize,
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = shape[..dim].iter().product();
@@ -1277,22 +1141,23 @@ pub fn call_gather(
            src_dim_size,
            right_size,
            ids_size,
            (input, input_offset),
            (ids, ids_offset),
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
@@ -1301,10 +1166,8 @@ pub fn call_scatter_add(
    src_shape: &[usize],
    dst_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
@@ -1327,22 +1190,23 @@ pub fn call_scatter_add(
            src_dim_size,
            right_size,
            dst_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_index_add(
    device: &Device,
    command_buffer: &CommandBufferRef,
@@ -1352,10 +1216,8 @@ pub fn call_index_add(
    dst_shape: &[usize],
    ids_shape: &[usize],
    dim: usize,
    input: &Buffer,
    input_offset: usize,
    ids: &Buffer,
    ids_offset: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
@@ -1379,16 +1241,16 @@ pub fn call_index_add(
            right_size,
            dst_dim_size,
            ids_dim_size,
            (input, input_offset),
            (ids, ids_offset),
            &input,
            &ids,
            output
        )
    );

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1654,8 +1516,7 @@ pub fn call_im2col1d_strided(
    shape: &[usize],
    strides: &[usize],
    (k_size, stride, padding, dilation): (usize, usize, usize, usize),
    input: &Buffer,
    input_offset: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -1667,20 +1528,9 @@ pub fn call_im2col1d_strided(
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            dst_el,
            l_out,
            k_size,
            stride,
            padding,
            dilation,
            shape,
            strides,
            (input, input_offset),
            output
        )
        (dst_el, l_out, k_size, stride, padding, dilation, shape, strides, &input, output)
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1697,8 +1547,7 @@ pub fn call_im2col_strided(
    shape: &[usize],
    strides: &[usize],
    (h_k, w_k, stride, padding, dilation): (usize, usize, usize, usize, usize),
    input: &Buffer,
    input_offset: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -1716,21 +1565,11 @@ pub fn call_im2col_strided(
    set_params!(
        encoder,
        (
            dst_el,
            h_out,
            w_out,
            h_k,
            w_k,
            stride,
            padding,
            dilation,
            shape,
            strides,
            (input, input_offset),
            dst_el, h_out, w_out, h_k, w_k, stride, padding, dilation, shape, strides, &input,
            output
        )
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1748,8 +1587,7 @@ pub fn call_upsample_nearest_2d(
    strides: &[usize],
    out_w: usize,
    out_h: usize,
    input: &Buffer,
    input_offset: usize,
    input: BufferOffset,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
@@ -1761,18 +1599,9 @@ pub fn call_upsample_nearest_2d(
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            out_w,
            out_h,
            scale_w,
            scale_h,
            shape,
            strides,
            (input, input_offset),
            output
        )
        (out_w, out_h, scale_w, scale_h, shape, strides, &input, output)
    );
    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1869,6 +1698,7 @@ pub enum GgmlDType {
    F32,
}

#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_t(
    device: &Device,
    command_buffer: &CommandBufferRef,
@@ -1884,16 +1714,16 @@ pub fn call_quantized_matmul_t(
    let ne00 = k as i64;
    let ne01 = n as i64;
    let ne02 = b as i64;
    let ne03 = 1 as i64;
    let ne03 = 1i64;

    let nb00 = 0i64;
    let nb01 = 0 as i64;
    let nb02 = 0 as i64;
    let nb01 = 0i64;
    let nb02 = 0i64;

    let ne10 = k as i64;
    let ne11 = m as i64;
    let ne12 = b as i64;
    let ne13 = 1 as i64;
    let ne13 = 1i64;

    let nb10 = 0i64;
    let nb11 = 0i64;
@@ -2128,6 +1958,7 @@ pub struct CallConvTranspose2dCfg<'a> {
    pub kernel_offset: usize,
}

#[allow(clippy::too_many_arguments)]
pub fn call_conv_transpose2d(
    device: &Device,
    command_buffer: &CommandBufferRef,
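The central refactor in this file replaces every `&Buffer` plus separate offset argument with a single `BufferOffset` from the new `utils` module. Its exact definition is not shown in the diff, but the call sites (`BufferOffset { buffer: &input, offset_in_bytes: 0 }`, `BufferOffset::zero_offset(&left)`) pin down a plausible shape:

```rust
// Presumed shape of utils::BufferOffset, reconstructed from its call
// sites in this diff rather than from the actual source file.
pub struct BufferOffset<'a> {
    pub buffer: &'a metal::Buffer,
    pub offset_in_bytes: usize,
}

impl<'a> BufferOffset<'a> {
    // Convenience constructor used throughout the updated tests.
    pub fn zero_offset(buffer: &'a metal::Buffer) -> Self {
        Self { buffer, offset_in_bytes: 0 }
    }
}
```

Bundling the buffer with its byte offset makes it harder to pair an offset with the wrong buffer, which is the class of bug the old `(buffer, offset)` tuples invited.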
@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    let ptr = data.as_ptr() as *const c_void;
    let size = (data.len() * std::mem::size_of::<T>()) as u64;
    let size = std::mem::size_of_val(data) as u64;
    device.new_buffer_with_data(ptr, size, options)
}

@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let input = BufferOffset {
        buffer: &input,
        offset_in_bytes: 0,
    };
    let output = new_buffer(&device, v);
    call_unary_contiguous(
        &device,
@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
        &kernels,
        name,
        v.len(),
        &input,
        input,
        &output,
    )
    .unwrap();
@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
        &kernels,
        name,
        x.len(),
        &left,
        &right,
        BufferOffset::zero_offset(&left),
        BufferOffset::zero_offset(&right),
        &output,
    )
    .unwrap();
@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let output = new_buffer(&device, v);
    let input = BufferOffset {
        buffer: &input,
        offset_in_bytes: offset,
    };
    let output_b = new_buffer(&device, v);
    let output = BufferOffset {
        buffer: &output_b,
        offset_in_bytes: 0,
    };
    let kernels = Kernels::new();
    call_unary_strided(
        &device,
@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
        &kernels,
        kernel,
        shape,
        &input,
        input,
        strides,
        offset,
        &output,
        0,
        output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, v.len())
    read_to_vec(&output_b, v.len())
}

#[test]
@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
        &kernels,
        name,
        v.len(),
        &input,
        0,
        BufferOffset::zero_offset(&input),
        &output,
    )
    .unwrap();
@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
        &kernels,
        "affine_f32",
        size,
        &input,
        BufferOffset::zero_offset(&input),
        &output,
        mul as f32,
        add as f32,
@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
        &kernels,
        "affine_f32_strided",
        shape,
        &input,
        BufferOffset::zero_offset(&input),
        strides,
        0,
        &output,
        mul as f32,
        add as f32,
@ -633,7 +641,7 @@ fn index_select_strided() {
fn index_select_f16() {
    let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
        .into_iter()
        .map(|x| f16::from_f32(x))
        .map(f16::from_f32)
        .collect();
    let shape = [5, 2];
    let stride = [2, 1];
@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, &embeddings);
    let ids_buffer = new_buffer(&device, &ids);
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        command_buffer,
        &kernels,
        name,
        shape,
@ -720,10 +728,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
        true,
        shape,
        stride,
        &embeddings_buffer,
        0,
        &ids_buffer,
        0,
        BufferOffset::zero_offset(&embeddings_buffer),
        BufferOffset::zero_offset(&ids_buffer),
        &dst_buffer,
    )
    .unwrap();
@ -746,8 +752,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(

    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let embeddings_buffer = new_buffer(&device, &embeddings);
    let ids_buffer = new_buffer(&device, &ids);
    let embeddings_buffer = new_buffer(&device, embeddings);
    let ids_buffer = new_buffer(&device, ids);

    let left_size: usize = shape[..dim].iter().product();
    let right_size: usize = shape[dim + 1..].iter().product();
@ -757,7 +763,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
    let kernels = Kernels::new();
    call_index_select(
        &device,
        &command_buffer,
        command_buffer,
        &kernels,
        name,
        shape,
@ -766,10 +772,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
        false,
        shape,
        stride,
        &embeddings_buffer,
        0,
        &ids_buffer,
        0,
        BufferOffset::zero_offset(&embeddings_buffer),
        BufferOffset::zero_offset(&ids_buffer),
        &dst_buffer,
    )
    .unwrap();
@ -811,8 +815,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
        &dims,
        &strides,
        out_length,
        &input,
        0,
        BufferOffset::zero_offset(&input),
        &output,
    )
    .unwrap();
@ -931,6 +934,7 @@ fn softmax() {
    );
}

#[allow(clippy::too_many_arguments)]
fn run_where_cond<I: Clone, T: Clone>(
    shape: &[usize],
    cond: &[I],
@ -965,18 +969,30 @@ fn run_where_cond<I: Clone, T: Clone>(
    );

    let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
    let cond = BufferOffset {
        buffer: &cond,
        offset_in_bytes: cond_offset,
    };
    let left = BufferOffset {
        buffer: &left,
        offset_in_bytes: left_offset,
    };
    let right = BufferOffset {
        buffer: &right,
        offset_in_bytes: cond_offset,
    };
    call_where_cond_strided(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        &cond,
        (&cond_stride, cond_offset),
        &left,
        (&left_stride, left_offset),
        &right,
        (&cond_stride, cond_offset),
        cond,
        &cond_stride,
        left,
        &left_stride,
        right,
        &cond_stride,
        &output,
    )
    .unwrap();
@ -1148,7 +1164,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
#[test]
fn random() {
    fn calc_mean(data: &[f32]) -> f32 {
        let sum = data.iter().sum::<f32>() as f32;
        let sum = data.iter().sum::<f32>();
        let count = data.len();
        assert!(count > 0);
        sum / count as f32
@ -1162,7 +1178,7 @@ fn random() {
        let variance = data
            .iter()
            .map(|value| {
                let diff = mean - (*value as f32);
                let diff = mean - *value;
                diff * diff
            })
            .sum::<f32>()
@ -1241,10 +1257,8 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &ids_buffer,
        0,
        BufferOffset::zero_offset(&input_buffer),
        BufferOffset::zero_offset(&ids_buffer),
        &output,
    )
    .unwrap();
@ -1346,10 +1360,8 @@ fn run_index_add<T: Clone, I: Clone + std::fmt::Debug>(
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &indices_buffer,
        0,
        BufferOffset::zero_offset(&input_buffer),
        BufferOffset::zero_offset(&indices_buffer),
        &output,
    )
    .unwrap();
@ -1787,6 +1799,7 @@ fn avg_pool2d_u32() {
    assert_eq!(results, expected);
}

#[allow(clippy::too_many_arguments)]
fn run_conv_transpose1d<T: Clone>(
    input: &[T],
    input_shape: &[usize],
@ -104,26 +104,18 @@ UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &d1, \
    constant size_t &d2, \
    constant size_t &src_s, \
    constant size_t &dst_s, \
    constant int64_t &d1, \
    constant int64_t &d2, \
    constant int64_t &src_s, \
    constant int64_t &dst_s, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
    uint2 idx [[thread_position_in_grid]] \
) { \
    tid *= 4; \
    if (tid >= d1 * d2) { \
        return; \
    } \
    size_t idx1 = tid / d2; \
    size_t idx2 = tid - idx1 * d2; \
    size_t src_idx = idx1 * src_s + idx2; \
    size_t dst_idx = idx1 * dst_s + idx2; \
    if (idx.x >= d1 || idx.y >= d2) return; \
    int64_t src_idx = idx.x * src_s + idx.y; \
    int64_t dst_idx = idx.x * dst_s + idx.y; \
    output[dst_idx] = input[src_idx]; \
    output[dst_idx+1] = input[src_idx+1]; \
    output[dst_idx+2] = input[src_idx+2]; \
    output[dst_idx+3] = input[src_idx+3]; \
}

COPY2D(copy2d_f32, float)
@ -183,5 +175,5 @@ BFLOAT_UNARY_OP(sign)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

COPY2D(copy2d_bf64, bfloat)
COPY2D(copy2d_bf16, bfloat)
#endif
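The COPY2D rewrite above drops the hand-rolled four-elements-per-thread 1D indexing in favour of a plain 2D launch: each thread handles one (idx.x, idx.y) element, and sizes/strides move to int64_t. A hedged host-side sketch of how such a kernel could be dispatched, pairing a thread-group size (for instance from the `get_block_dims` helper introduced in `utils.rs` below) with a covering group count; this pairing is illustrative, not the actual call site:

```rust
use metal::MTLSize;

// Illustrative only: given the d1 x d2 element grid of a copy2d_* kernel and
// a thread-group size, compute the thread-group count so that every
// (idx.x, idx.y) pair in the grid is covered by exactly one thread group.
fn copy2d_dispatch(d1: u64, d2: u64, block: MTLSize) -> (MTLSize, MTLSize) {
    let grid = MTLSize {
        width: (d1 + block.width - 1) / block.width,
        height: (d2 + block.height - 1) / block.height,
        depth: 1,
    };
    (grid, block)
}
```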
162
candle-metal-kernels/src/utils.rs
Normal file
@ -0,0 +1,162 @@
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
use std::ffi::c_void;

/// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
/// actual total buffer length).
/// Then kernels can just do their op on their single point in the buffer.
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
    let size = length as u64;
    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
    let count = (size + width - 1) / width;
    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };

    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}

// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
    let mut pows0 = 0u64;
    let mut pows1 = 0u64;
    let mut pows2 = 0u64;
    let mut sum = 0u64;
    loop {
        let presum = sum;
        // Check all the pows
        if dim0 >= (1 << (pows0 + 1)) {
            pows0 += 1;
            sum += 1;
        }
        if sum == 10 {
            break;
        }
        if dim1 >= (1 << (pows1 + 1)) {
            pows1 += 1;
            sum += 1;
        }
        if sum == 10 {
            break;
        }
        if dim2 >= (1 << (pows2 + 1)) {
            pows2 += 1;
            sum += 1;
        }
        if sum == presum || sum == 10 {
            break;
        }
    }
    MTLSize {
        width: 1 << pows0,
        height: 1 << pows1,
        depth: 1 << pows2,
    }
}

pub(crate) fn set_param<P: EncoderParam>(
    encoder: &ComputeCommandEncoderRef,
    position: u64,
    data: P,
) {
    <P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
    ($type:ty) => {
        impl EncoderParam for $type {
            fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
                encoder.set_bytes(
                    position,
                    core::mem::size_of::<$type>() as u64,
                    &data as *const $type as *const c_void,
                );
            }
        }
    };
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u32);
primitive!(u64);
primitive!(f32);

pub struct BufferOffset<'a> {
    pub buffer: &'a Buffer,
    pub offset_in_bytes: usize,
}

impl<'a> BufferOffset<'a> {
    pub fn zero_offset(buffer: &'a Buffer) -> Self {
        Self {
            buffer,
            offset_in_bytes: 0,
        }
    }
}

impl<T> EncoderParam for &[T] {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_bytes(
            position,
            core::mem::size_of_val(data) as u64,
            data.as_ptr() as *const c_void,
        );
    }
}

impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}

impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

impl<'a> EncoderParam for &BufferOffset<'a> {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
    }
}

impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data), 0);
    }
}

impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}

#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        let mut _index = 0;
        $(
            $crate::utils::set_param($encoder, _index, $param);
            _index += 1;
        )*
    );
}
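Taken together, the helpers in this new file are meant to shrink every `call_*` function to the same shape. A minimal sketch of a caller inside the crate; the kernel name and argument list are made up for illustration, only `linear_split`, `set_params!`, `BufferOffset`, and the `EncoderParam` impls above are real:

```rust
use metal::{Buffer, CommandBufferRef, ComputePipelineState, MTLResourceUsage};

// Hypothetical call site: wire a scalar length, an offset input, and an
// output buffer to consecutive argument slots, then dispatch over a 1d
// linear split of the buffer.
fn call_example(
    command_buffer: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    length: usize,
    input: BufferOffset,
    output: &Buffer,
) {
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    // &BufferOffset forwards its byte offset to set_buffer; scalars go
    // through set_bytes via the primitive! impls above.
    set_params!(encoder, (length, &input, output));
    let (thread_group_count, thread_group_size) = linear_split(pipeline, length);
    encoder.use_resource(input.buffer, MTLResourceUsage::Read);
    encoder.use_resource(output, MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
}
```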
@ -498,6 +498,53 @@ impl<'a> VarBuilder<'a> {
        let pth = candle::pickle::PthTensors::new(p, None)?;
        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
    }

    /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
    /// passing the new names to the inner VarBuilder.
    ///
    /// ```rust
    /// use candle::{Tensor, DType, Device};
    ///
    /// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    /// let tensors: std::collections::HashMap<_, _> = [
    ///     ("foo".to_string(), a),
    /// ]
    /// .into_iter()
    /// .collect();
    /// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
    /// assert!(vb.contains_tensor("foo"));
    /// assert!(vb.get((2, 3), "foo").is_ok());
    /// assert!(!vb.contains_tensor("bar"));
    /// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
    /// assert!(vb.contains_tensor("bar"));
    /// assert!(vb.contains_tensor("foo"));
    /// assert!(vb.get((2, 3), "bar").is_ok());
    /// assert!(vb.get((2, 3), "foo").is_ok());
    /// assert!(!vb.contains_tensor("baz"));
    /// # Ok::<(), candle::Error>(())
    /// ```
    pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
        let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
        self.rename(f)
    }

    pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
        let dtype = self.dtype();
        let device = self.device().clone();
        let path = self.path.clone();
        let backend = Rename::new(self, renamer);
        let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
        let data = TensorData {
            backend,
            dtype,
            device,
        };
        Self {
            data: Arc::new(data),
            path,
            _phantom: std::marker::PhantomData,
        }
    }
}

pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
@ -618,3 +665,49 @@ impl Backend for ShardedSafeTensors {
        self.0.get(name).is_ok()
    }
}

/// This traits specifies a way to rename the queried names into names that are stored in an inner
/// VarBuilder.
pub trait Renamer {
    /// This is applied to the name obtained by a name call and the resulting name is passed to the
    /// inner VarBuilder.
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
}

pub struct Rename<'a, R: Renamer> {
    inner: VarBuilder<'a>,
    renamer: R,
}

impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
    fn get(
        &self,
        s: Shape,
        name: &str,
        h: crate::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let name = self.renamer.rename(name);
        self.inner
            .get_with_hints_dtype(s, &name, h, dtype)?
            .to_device(dev)
    }

    fn contains_tensor(&self, name: &str) -> bool {
        let name = self.renamer.rename(name);
        self.inner.contains_tensor(&name)
    }
}

impl<'a, R: Renamer> Rename<'a, R> {
    pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
        Self { inner, renamer }
    }
}

impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
    fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
        std::borrow::Cow::Owned(self(v))
    }
}
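Beyond the alias case in the doc-test, the typical use of `rename_f` is adapting checkpoints whose tensors live under a different prefix than the one the loading code queries. A hedged sketch; the `model.` prefix is illustrative, not tied to a specific checkpoint:

```rust
use candle_nn::VarBuilder;

// Make lookups for "decoder.weight" resolve to "model.decoder.weight" in the
// underlying checkpoint; any capture-free Fn(&str) -> String works here since
// rename_f requires Sync + Send + 'static.
fn with_model_prefix(vb: VarBuilder<'_>) -> VarBuilder<'_> {
    vb.rename_f(|name: &str| format!("model.{name}"))
}
```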
@ -2,7 +2,7 @@ use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::{collections::HashMap, usize};

pub type Value = Tensor;

@ -508,17 +508,33 @@ pub fn simple_eval(
            values.insert(node.output[0].clone(), xs);
        }
        "Gather" => {
            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
            let xs = get(&node.input[0])?;
            let indices = get(&node.input[1])?;
            let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
            let axis = xs.normalize_axis(axis)?;
            // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
            // differentiable way.
            let xs = if indices.rank() == 0 {
                let index = indices.to_vec0::<i64>()? as usize;
                xs.narrow(axis, index, 1)?.squeeze(axis)?
            } else {
                todo!("implement gather for {xs:?} {indices:?} axis {axis}")

            // In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
            // tensor directly, but candle does not support tensor indexing at the moment, so
            // some workarounds must be done.
            let xs = match indices.dims() {
                [] => {
                    let index = indices.to_vec0::<i64>()? as usize;
                    xs.narrow(axis, index, 1)?.squeeze(axis)?
                }
                [_] => xs.index_select(indices, axis)?,
                [first, _] => {
                    let mut v = Vec::with_capacity(*first);
                    for i in 0..*first {
                        v.push(xs.index_select(&indices.get(i)?, axis)?)
                    }
                    Tensor::stack(&v, axis)?
                }
                _ => {
                    // TODO: Provide an op to handle the ONNX generalized gather op ideally in a
                    // differentiable way.
                    todo!("implement gather for {xs:?} {indices:?} axis {axis}")
                }
            };
            values.insert(node.output[0].clone(), xs);
        }
@ -781,6 +797,29 @@ pub fn simple_eval(
            let input = get(&node.input[0])?;
            values.insert(node.output[0].clone(), input.clone());
        }
        // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
        // TODO: This version is only compatible with ReduceMean V13 and below.
        "ReduceMean" => {
            let input = get(&node.input[0])?;
            let axes = get_attr_opt::<[i64]>(node, "axes")?;
            let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);

            let n_dims = input.dims().len();

            let axes: Vec<usize> = if let Some(axes) = axes {
                axes.iter()
                    .map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
                    .collect()
            } else {
                (0..n_dims).collect()
            };
            let output = if keepdims == 1 {
                input.mean_keepdim(axes)?
            } else {
                input.mean(axes)?
            };
            values.insert(node.output[0].clone(), output);
        }
        op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
    }
}
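The axis handling in the ReduceMean arm follows the ONNX convention that a negative axis counts back from the last dimension. A standalone restatement of that normalization, checking the `-2` case exercised by the negative_axes_keepdims test below:

```rust
// Mirror of the axis normalization in the ReduceMean arm above: a negative
// axis e on an n_dims-dimensional input maps to n_dims + e.
fn normalize_axes(axes: &[i64], n_dims: usize) -> Vec<usize> {
    axes.iter()
        .map(|e| (if *e < 0 { n_dims as i64 + *e } else { *e }) as usize)
        .collect()
}

fn main() {
    // For a rank-3 input, axis -2 is axis 1.
    assert_eq!(normalize_axes(&[-2], 3), vec![1]);
}
```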
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use candle::{Device, Result, Tensor};
use candle::{Device, NdArray, Result, Tensor};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;

@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> {
// #[test]

// "Gather"
// #[test]
#[test]
fn test_gather_operation() -> Result<()> {
    // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
    test(
        &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]],
        &[[0i64, 1], [1, 2]],
        0,
        &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]],
    )?;

    // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
    test(
        &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]],
        &[[0i64, 2]],
        1,
        &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]],
    )?;

    // all the tests below are generated from numpy.take, which works like
    // onnx's Gather operation.
    test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?;

    test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?;

    test(
        &[[1.0], [2.0], [3.0], [4.0]],
        &[3i64, 2],
        0,
        &[[4.0], [3.0]],
    )?;

    test(
        &[
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        1i64,
        0,
        &[[5.0, 6.0], [7.0, 8.0]],
    )?;

    test(
        &[
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        &[1i64, 0],
        0,
        &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]],
    )?;

    fn test(
        data: impl NdArray,
        indices: impl NdArray,
        axis: i64,
        expected: impl NdArray,
    ) -> Result<()> {
        let att_axis = AttributeProto {
            name: "axis".to_string(),
            ref_attr_name: "axis".to_string(),
            i: axis,
            doc_string: "axis".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "Gather".to_string(),
                domain: "".to_string(),
                attribute: vec![att_axis],
                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
        inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

        let expected = Tensor::new(expected, &Device::Cpu)?;
        match expected.dims().len() {
            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }
    Ok(())
}

// "Shape"
#[test]
@ -1335,3 +1462,180 @@ fn test_relu_operation() -> Result<()> {

// "Cast"
// #[test]

// "ReduceMean"
#[test]
fn test_reduce_mean() -> Result<()> {
    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 default_axes_keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        None,
        1,
        &[[[18.25]]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 do_no_keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1]),
        0,
        &[[12.5, 1.5], [35.0, 1.5], [57.5, 1.5]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1]),
        1,
        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 negative_axes_keepdims
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![-2]),
        1,
        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
    )?;

    // All the test data below was generated based on numpy's np.mean
    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1, 2]),
        0,
        &[7.0, 18.25, 29.5],
    )?;

    test(
        &[
            [[5., 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        Some(vec![1, 2]),
        1,
        &[[[7.0]], [[18.25]], [[29.5]]],
    )?;

    test(&[1., 2., 3.], None, 1, &[2.0])?;

    fn test(
        data: impl NdArray,
        axes: Option<Vec<i64>>,
        keepdims: i64,
        expected: impl NdArray,
    ) -> Result<()> {
        let has_axes = axes.is_some();

        let att_axes = AttributeProto {
            name: "axes".to_string(),
            ref_attr_name: "axes".to_string(),
            i: 0,
            doc_string: "axes".to_string(),
            r#type: 7,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: axes.unwrap_or_default(),
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let att_keepdims = AttributeProto {
            name: "keepdims".to_string(),
            ref_attr_name: "keepdims".to_string(),
            i: keepdims,
            doc_string: "keepdims".to_string(),
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "ReduceMean".to_string(),
                domain: "".to_string(),
                attribute: if has_axes {
                    vec![att_axes, att_keepdims]
                } else {
                    vec![att_keepdims]
                },
                input: vec![INPUT_X.to_string()],
                output: vec![OUTPUT_Z.to_string()],
                name: "".to_string(),
                doc_string: "".to_string(),
            }],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            }],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

        let expected = Tensor::new(expected, &Device::Cpu)?;
        match expected.dims().len() {
            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    Ok(())
}
@ -120,7 +120,7 @@ fn rotate_half(x: &Tensor) -> Result<Tensor> {
    Ok(x21)
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconRotaryEmbedding {
    inv_freq: Tensor,
    cache: Option<(usize, Tensor, Tensor)>,
@ -179,12 +179,14 @@ impl FalconRotaryEmbedding {

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let on_true = Tensor::new(on_true, on_false.device())?
        .to_dtype(on_false.dtype())?
        .broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconAttention {
    query_key_value: Linear,
    dense: Linear,
@ -313,9 +315,13 @@ impl FalconAttention {
        let attn_output = self.dense.forward(&attn_output)?;
        Ok(attn_output)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconMlp {
    dense_h_to_4h: Linear,
    dense_4h_to_h: Linear,
@ -340,7 +346,7 @@ impl FalconMlp {
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
struct FalconDecoderLayer {
    inp_layernorm: LayerNorm,
    self_attention: FalconAttention,
@ -400,9 +406,13 @@ impl FalconDecoderLayer {
        let output = (mlp_output + residual)?;
        Ok(output)
    }

    pub fn clear_kv_cache(&mut self) {
        self.self_attention.clear_kv_cache()
    }
}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Falcon {
    word_embeddings: Embedding,
    blocks: Vec<FalconDecoderLayer>,
@ -475,4 +485,10 @@ impl Falcon {
        let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
        Ok(logits)
    }

    pub fn clear_kv_cache(&mut self) {
        for block in self.blocks.iter_mut() {
            block.clear_kv_cache()
        }
    }
}
@ -1,7 +1,7 @@
use std::sync::Arc;

use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};

fn default_max_position_embeddings() -> usize {
    4096
@ -11,7 +11,9 @@ fn default_max_position_embeddings() -> usize {
pub struct Config {
    pub attention_bias: bool,
    pub head_dim: usize,
    pub hidden_act: candle_nn::Activation,
    // The code gemma configs include both hidden_act and hidden_activation.
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
@ -25,6 +27,16 @@ pub struct Config {
    pub max_position_embeddings: usize,
}

impl Config {
    fn hidden_act(&self) -> Result<Activation> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(_), Some(_)) => candle::bail!("both hidden_act and hidden_activation are set"),
            (None, None) => candle::bail!("none of hidden_act and hidden_activation are set"),
        }
    }
}

#[derive(Debug, Clone)]
struct RmsNorm {
    weight: Tensor,
@ -126,7 +138,7 @@ impl MLP {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            act_fn: cfg.hidden_act()?,
        })
    }
}
@ -179,18 +191,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@ -227,8 +227,8 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@ -16,6 +16,8 @@ pub struct LlamaConfig {
    pub rms_norm_eps: f64,
    #[serde(default = "default_rope")]
    pub rope_theta: f32,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
}

fn default_rope() -> f32 {
@ -34,6 +36,8 @@ impl LlamaConfig {
            rms_norm_eps: self.rms_norm_eps,
            rope_theta: self.rope_theta,
            use_flash_attn,
            bos_token_id: self.bos_token_id,
            eos_token_id: self.eos_token_id,
        }
    }
}
@ -49,6 +53,8 @@ pub struct Config {
    pub use_flash_attn: bool,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
}

impl Config {
@ -63,6 +69,8 @@ impl Config {
            use_flash_attn,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000.0,
            bos_token_id: None,
            eos_token_id: None,
        }
    }

@ -77,6 +85,8 @@ impl Config {
            use_flash_attn,
            rms_norm_eps: 1e-5,
            rope_theta: 10_000.0,
            bos_token_id: None,
            eos_token_id: None,
        }
    }
}
@ -106,7 +116,6 @@ impl Cache {
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        // This is different from the paper, see:
        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;
        Ok(Self {
@ -166,16 +175,10 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
        let _enter = self.span_rot.enter();
        let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
        let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
        let cos = cache.cos.narrow(0, index_pos, seq_len)?;
        let sin = cache.sin.narrow(0, index_pos, seq_len)?;
        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
        Ok(rope)
        candle_nn::rotary_emb::rope(x, &cos, &sin)
    }

    fn forward(
@ -193,10 +196,12 @@ impl CausalSelfAttention {

        let q = q
            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?;
            .transpose(1, 2)?
            .contiguous()?;
        let k = k
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;
            .transpose(1, 2)?
            .contiguous()?;
        let mut v = v
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;
@ -256,17 +261,7 @@ impl CausalSelfAttention {
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.num_attention_heads / self.num_key_value_heads;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(2)?
                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
            Ok(x)
        }
        crate::utils::repeat_kv(x, self.num_attention_heads / self.num_key_value_heads)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
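The hand-rolled rotary embedding above is replaced by `candle_nn::rotary_emb::rope`; the deleted lines spell out the math the fused op has to reproduce. A reference sketch of that computation, assuming cos/sin tables already broadcastable against x as in the old code:

```rust
use candle::{Result, Tensor, D};

// What the removed apply_rotary_emb body computed (non-interleaved RoPE):
// split the last dim in halves, rotate, and recombine with cos/sin tables.
fn rope_reference(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let hidden_size = x.dim(D::Minus1)?;
    let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
    let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
    let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
    x.broadcast_mul(cos)? + rotate_x.broadcast_mul(sin)?
}
```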
@ -1,4 +1,3 @@
#![allow(unused)]
/// A fast implementation of mamba for inference only.
/// This is based on: https://github.com/LaurentMazare/mamba.rs
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
@ -38,12 +37,12 @@ pub struct State {
}

impl State {
    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
    pub fn new(batch_size: usize, cfg: &Config, dtype: DType, device: &Device) -> Result<Self> {
        let mut hs = Vec::with_capacity(cfg.n_layer);
        let mut prev_xs = Vec::with_capacity(cfg.n_layer);
        for _i in 0..cfg.n_layer {
            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), dtype, device)?;
            let x = Tensor::zeros((batch_size, cfg.d_inner()), dtype, device)?;
            hs.push(h);
            prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
        }
@ -128,8 +127,8 @@ impl MambaBlock {
        let delta = delta.apply(&self.dt_proj)?;
        // softplus
        let delta = (delta.exp()? + 1.)?.log()?;
        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
        let d = self.d.to_dtype(candle::DType::F32)?;
        let a = self.a_log.to_dtype(delta.dtype())?.exp()?.neg()?;
        let d = self.d.to_dtype(delta.dtype())?;

        // Selective scan part
        // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
@ -178,6 +177,7 @@ pub struct Model {
    layers: Vec<ResidualBlock>,
    norm_f: RmsNorm,
    lm_head: Linear,
    dtype: DType,
}

impl Model {
@ -196,6 +196,7 @@ impl Model {
            layers,
            norm_f,
            lm_head,
            dtype: vb.dtype(),
        })
    }

@ -208,4 +209,8 @@ impl Model {
        state.pos += 1;
        xs.apply(&self.norm_f)?.apply(&self.lm_head)
    }

    pub fn dtype(&self) -> DType {
        self.dtype
    }
}
@ -216,18 +216,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@ -266,8 +254,8 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let attn_output = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@ -158,18 +158,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@ -206,8 +194,8 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let attn_output = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
@ -37,12 +37,14 @@ pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_moondream;
pub mod quantized_mpt;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;
pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod qwen2;
pub mod qwen2_moe;
pub mod recurrent_gemma;
pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
@ -104,8 +104,8 @@ impl GroupedQueryAttention {
        };
        self.kv_cache = Some((key.clone(), value.clone()));
        let query = query.contiguous()?;
        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
        let attn_bias = {
            let s_q = query.dim(D::Minus2)?;
@ -134,20 +134,6 @@ impl GroupedQueryAttention {
    }
}

// This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
// The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
// (batch, num_attention_heads, seqlen, head_dim)
pub(crate) fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(xs)
    } else {
        let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
        xs.unsqueeze(2)?
            .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
            .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
    }
}

#[derive(Debug, Clone)]
struct Ffn {
    up_proj: Linear,
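This removal repeats across most models in the diff: each local `repeat_kv` copy is replaced by a single shared helper. Reconstructed from the deleted bodies, `crate::utils::repeat_kv` presumably looks like the sketch below:

```rust
use candle::{Result, Tensor};

// Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): takes the
// hidden states from (batch, num_kv_heads, seq_len, head_dim) to
// (batch, num_kv_heads * n_rep, seq_len, head_dim).
pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(xs)
    } else {
        let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
        xs.unsqueeze(2)?
            .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
            .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
    }
}
```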
@ -174,15 +174,7 @@ impl Attention {
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_heads / self.num_kv_heads;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
        crate::utils::repeat_kv(xs, self.num_heads / self.num_kv_heads)
    }

    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
@ -205,9 +205,9 @@ impl LayerWeights {
        };
        self.kv_cache = Some((k.clone(), v.clone()));

        // Support for MQA, useful for 70B models.
        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;
        // Support for MQA, useful for 70B models and mistral.
        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;

        let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
        let att = match mask {
@ -224,20 +224,6 @@ impl LayerWeights {
        let y = self.attention_wo.forward(&y)?;
        Ok(y)
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_head / self.n_kv_head;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(2)?
                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
            Ok(x)
        }
    }
}

#[derive(Debug, Clone)]
@ -235,6 +235,7 @@ pub mod transformer {
            xs = layer.forward(&xs, pos, &mask)?
        }
        xs.narrow(1, seqlen - 1, 1)?
            .contiguous()?
            .apply(&self.norm)?
            .apply(&self.output)
    }
@ -122,18 +122,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@ -172,8 +160,8 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
@ -71,8 +71,8 @@ impl GroupedQueryAttention {
        };
        self.kv_cache = Some((key.clone(), value.clone()));
        let query = query.contiguous()?;
        let key = super::mpt::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let value = super::mpt::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let key = crate::utils::repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let value = crate::utils::repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
        let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
        let attn_bias = {
            let s_q = query.dim(D::Minus2)?;
412
candle-transformers/src/models/quantized_recurrent_gemma.rs
Normal file
@ -0,0 +1,412 @@
use crate::quantized_nn::{linear_b as linear, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use std::sync::Arc;

use crate::models::recurrent_gemma::{Config, Rglru, RmsNorm, RotaryEmbedding, TemporalBlockType};

fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    Ok(RmsNorm::from_weight(weight, eps))
}

#[derive(Debug, Clone)]
struct Mlp {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: candle_nn::Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let intermediate_size = cfg.intermediate_size / 2;
        let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
        let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
        let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_activation,
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
    }
}

fn rglru(cfg: &Config, vb: VarBuilder) -> Result<Rglru> {
    let h = cfg.hidden_size;
    let lru_width = cfg.lru_width.unwrap_or(h);
    let n_heads = cfg.num_attention_heads;
    let block_width = lru_width / n_heads;
    let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
    let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
    let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
    let recurrent_gate_weight =
        vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
    let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
    Ok(Rglru {
        recurrent_param: recurrent_param.dequantize(vb.device())?,
        input_gate_bias: input_gate_bias.dequantize(vb.device())?,
        input_gate_weight: input_gate_weight.dequantize(vb.device())?,
        recurrent_gate_bias: recurrent_gate_bias.dequantize(vb.device())?,
        recurrent_gate_weight: recurrent_gate_weight.dequantize(vb.device())?,
        block_width,
        n_heads,
        recurrent_states: None,
    })
}

#[derive(Debug, Clone)]
struct RecurrentBlock {
    linear_y: Linear,
    linear_x: Linear,
    linear_out: Linear,
    conv_1d: candle_nn::Conv1d,
    conv1d_state: Option<Tensor>,
    conv1d_width: usize,
    rg_lru: Rglru,
    act_fn: candle_nn::Activation,
}

impl RecurrentBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let lru_width = cfg.lru_width.unwrap_or(h);
        let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
        let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
        let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;

        let conv_1d = {
            let ws = vb
                .get((lru_width, 1, cfg.conv1d_width), "conv_1d.weight")?
                .dequantize(vb.device())?;
            let bs = vb.get(lru_width, "conv_1d.bias")?.dequantize(vb.device())?;
            let config = candle_nn::Conv1dConfig {
                groups: lru_width,
                padding: cfg.conv1d_width - 1,
                ..Default::default()
            };
            candle_nn::Conv1d::new(ws, Some(bs), config)
        };
        let rg_lru = rglru(cfg, vb.pp("rg_lru"))?;
        Ok(Self {
            linear_y,
            linear_x,
            linear_out,
            conv_1d,
            conv1d_state: None,
            conv1d_width: cfg.conv1d_width,
            rg_lru,
            act_fn: cfg.hidden_activation,
        })
    }

    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len, _) = xs.dims3()?;

        let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
        let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
        let x_branch = if pos == 0 {
            let x_len = x_branch.dim(D::Minus1)?;
            let pad = self.conv1d_width as i64 - x_len as i64 - 1;
            let padded = match pad.cmp(&0) {
                std::cmp::Ordering::Equal => x_branch.clone(),
                std::cmp::Ordering::Less => {
                    let rev_pad = (-pad) as usize;
                    x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
                }
                std::cmp::Ordering::Greater => {
                    x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
                }
            };
            self.conv1d_state = Some(padded);
            x_branch
                .apply(&self.conv_1d)?
                .narrow(D::Minus1, 0, seq_len)?
        } else {
            let conv_state = match self.conv1d_state.as_ref() {
                None => candle::bail!("empty cache despite pos > 0"),
                Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
            };
            let w = self.conv_1d.weight().i((.., 0, ..))?;
            let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
            let x_branch = match self.conv_1d.bias() {
                None => x_branch,
                Some(b) => x_branch.broadcast_add(b)?,
            };
            let x_branch = x_branch.unsqueeze(D::Minus1)?;
            self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
            x_branch
        };
        let x_branch = x_branch.transpose(1, 2)?;
        let x_branch = self.rg_lru.forward(&x_branch, pos)?;
        (x_branch * y_branch)?.apply(&self.linear_out)
    }
}

#[derive(Debug, Clone)]
struct SdpaAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    kv_cache: Option<(Tensor, Tensor)>,
    rotary_emb: Arc<RotaryEmbedding>,
}

impl SdpaAttention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let n_heads = cfg.num_attention_heads;
        let n_kv_heads = cfg.num_key_value_heads;
        let hd = cfg.head_dim;
        let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
        let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
        let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
        let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_heads,
            n_kv_heads,
            head_dim: hd,
            hidden_size: h,
            kv_cache: None,
            rotary_emb,
        })
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_heads / self.n_kv_heads;
        crate::utils::repeat_kv(x, n_rep)
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        let (bsz, q_len, _) = xs.dims3()?;

        let query_states = xs.apply(&self.q_proj)?;
        let key_states = xs.apply(&self.k_proj)?;
        let value_states = xs.apply(&self.v_proj)?;

        let query_states = query_states
            .reshape((bsz, q_len, self.n_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let query_states = query_states.chunk(2, D::Minus1)?;
        let key_states = key_states.chunk(2, D::Minus1)?;
        let (query_rot, key_rot) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
        let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
        let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let xs = {
            let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
            let att = if q_len == 1 {
                att
            } else {
                match attention_mask {
                    None => att,
                    Some(mask) => att.broadcast_add(mask)?,
                }
            };
            let att = candle_nn::ops::softmax_last_dim(&att)?;
            att.matmul(&value_states.contiguous()?)?
        };

        let xs = xs
            .transpose(1, 2)?
            .reshape((bsz, q_len, self.hidden_size))?;
        self.o_proj.forward(&xs)
    }
}

#[derive(Debug, Clone)]
enum TemporalBlock {
    Recurrent(RecurrentBlock),
    Attention(SdpaAttention),
}

impl TemporalBlock {
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        match self {
            Self::Recurrent(b) => b.forward(xs, pos),
            Self::Attention(b) => b.forward(xs, attention_mask, pos),
        }
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    temporal_pre_norm: RmsNorm,
    channel_pre_norm: RmsNorm,
    temporal_block: TemporalBlock,
    mlp_block: Mlp,
}

impl DecoderLayer {
    fn new(
        block_idx: usize,
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let h = cfg.hidden_size;
        let temporal_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
        let channel_pre_norm = rms_norm(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
        let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
            TemporalBlockType::Recurrent => {
                let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
                TemporalBlock::Recurrent(block)
            }
            TemporalBlockType::Attention => {
                let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
                TemporalBlock::Attention(block)
            }
        };
        let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
        Ok(Self {
            temporal_pre_norm,
            channel_pre_norm,
            temporal_block,
            mlp_block,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.temporal_pre_norm)?;
        let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
        xs + residual
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_norm: RmsNorm,
    lm_head: Linear,
    hidden_size: usize,
    logits_soft_cap: f64,
    device: Device,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embed_tokens = Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(DType::F32, cfg, vb.device())?);
|
||||
let vb_b = vb.pp("layers");
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
for idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let final_norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
|
||||
let lm_head = linear(
|
||||
cfg.hidden_size,
|
||||
cfg.vocab_size,
|
||||
false,
|
||||
vb.pp("embed_tokens"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
final_norm,
|
||||
lm_head,
|
||||
hidden_size: cfg.hidden_size,
|
||||
logits_soft_cap: cfg.logits_soft_cap,
|
||||
device: vb.device().clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(DType::F32)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = xs.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
|
||||
Some(mask)
|
||||
};
|
||||
let xs = xs.apply(&self.embed_tokens)?;
|
||||
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
|
||||
}
|
||||
let logits = xs
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.final_norm)?
|
||||
.apply(&self.lm_head)?;
|
||||
let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
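For orientation (not part of this diff): `Model::forward` takes `(batch, seq_len)` token ids plus a position offset and returns soft-capped logits for the last position only, so a sampling loop drives it one chunk at a time. A minimal sketch against this interface, assuming a `Model` built via `Model::new` and prompt tokens coming from a tokenizer; the seed/temperature values are arbitrary:

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

// Illustrative driver only: feed the prompt once, then one token per step.
fn generate(model: &mut Model, prompt: &[u32], device: &Device, steps: usize) -> Result<Vec<u32>> {
    let mut logits_processor = LogitsProcessor::new(299792458, Some(0.8), None);
    let mut tokens = prompt.to_vec();
    let mut pos = 0;
    for _ in 0..steps {
        // First step feeds the whole prompt, later steps only the newest token.
        let ctx = &tokens[pos..];
        let input = Tensor::new(ctx, device)?.unsqueeze(0)?;
        let logits = model.forward(&input, pos)?; // (1, 1, vocab), already soft-capped
        let logits = logits.squeeze(0)?.squeeze(0)?;
        pos += ctx.len();
        tokens.push(logits_processor.sample(&logits)?);
    }
    Ok(tokens)
}
```
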
@@ -94,18 +94,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@@ -152,8 +140,9 @@ impl Attention {
            self.kv_cache = Some((key_states.clone(), value_states.clone()));
        }

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);

@@ -146,18 +146,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@@ -194,8 +182,9 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);

@@ -151,18 +151,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@@ -199,8 +187,9 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);

643 candle-transformers/src/models/recurrent_gemma.rs Normal file
@@ -0,0 +1,643 @@
// This implementation is based on the python version from huggingface/transformers.
// https://github.com/huggingface/transformers/blob/b109257f4fb8b1166e7c53cc5418632014ed53a5/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L2
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Linear, VarBuilder};
use std::sync::Arc;

#[derive(serde::Deserialize, Debug, Clone, Copy)]
#[serde(rename_all = "snake_case")]
pub enum TemporalBlockType {
    Attention,
    Recurrent,
}

#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub num_hidden_layers: usize,
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub head_dim: usize,
    pub lru_width: Option<usize>,
    pub attention_window_size: usize,
    pub conv1d_width: usize,
    pub logits_soft_cap: f64,
    pub hidden_activation: candle_nn::Activation,
    pub partial_rotary_factor: f64,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    #[serde(alias = "_block_types")]
    pub block_types: Vec<TemporalBlockType>,
    pub attention_bias: bool,
    #[serde(default = "default_max_seq_len")]
    pub max_seq_len: usize,
}

fn default_max_seq_len() -> usize {
    8192
}

#[derive(Debug, Clone)]
pub(crate) struct RmsNorm {
    weight: Tensor,
    eps: f64,
}

impl RmsNorm {
    pub(crate) fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(dim, "weight")?;
        Ok(Self { weight, eps })
    }

    pub(crate) fn from_weight(weight: Tensor, eps: f64) -> Self {
        Self { weight, eps }
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&(&self.weight + 1.0)?)
    }
}

#[derive(Debug, Clone)]
pub(crate) struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
    pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        if cfg.partial_rotary_factor != 0.5 {
            candle::bail!("partial-rotary-factor {} <> 0.5", cfg.partial_rotary_factor)
        }
        let dim = cfg.head_dim / 2;
        let max_seq_len = cfg.max_seq_len;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    pub(crate) fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}

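Worth noting before the attention code below: the constructor bails unless `partial_rotary_factor == 0.5`, i.e. only the first half of each head dimension is rotated. `SdpaAttention::forward` implements this by splitting q/k with `chunk(2, D::Minus1)` and passing only chunk 0 through `apply_rotary_emb_qkv`. As a sketch of the intended transform:

```latex
q = [\,q_{\mathrm{rot}},\ q_{\mathrm{pass}}\,], \qquad
\tilde{q} = [\,q_{\mathrm{rot}} \odot \cos\theta_t
  + \mathrm{rotate\_half}(q_{\mathrm{rot}}) \odot \sin\theta_t,\ \ q_{\mathrm{pass}}\,]
```
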
#[derive(Debug, Clone)]
struct Mlp {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: candle_nn::Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let intermediate_size = cfg.intermediate_size / 2;
        let gate_proj = linear(h, intermediate_size, true, vb.pp("gate_proj"))?;
        let up_proj = linear(h, intermediate_size, true, vb.pp("up_proj"))?;
        let down_proj = linear(intermediate_size, h, true, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_activation,
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        (gate * xs.apply(&self.up_proj))?.apply(&self.down_proj)
    }
}

// Real-Gated Linear Recurrent Unit
#[derive(Debug, Clone)]
pub(crate) struct Rglru {
    pub(crate) recurrent_param: Tensor,
    pub(crate) input_gate_weight: Tensor,
    pub(crate) input_gate_bias: Tensor,
    pub(crate) recurrent_gate_weight: Tensor,
    pub(crate) recurrent_gate_bias: Tensor,
    pub(crate) block_width: usize,
    pub(crate) n_heads: usize,
    pub(crate) recurrent_states: Option<Tensor>,
}

fn baddbmm(a: &Tensor, b: &Tensor, c: &Tensor) -> Result<Tensor> {
    a.broadcast_add(&b.matmul(c)?)
}

fn softplus(xs: &Tensor) -> Result<Tensor> {
    (xs.exp()? + 1.0)?.log()
}

impl Rglru {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let lru_width = cfg.lru_width.unwrap_or(h);
        let n_heads = cfg.num_attention_heads;
        let block_width = lru_width / n_heads;
        let recurrent_param = vb.get((lru_width,), "recurrent_param")?;
        let input_gate_weight = vb.get((n_heads, block_width, block_width), "input_gate_weight")?;
        let input_gate_bias = vb.get((n_heads, block_width), "input_gate_bias")?;
        let recurrent_gate_weight =
            vb.get((n_heads, block_width, block_width), "recurrent_gate_weight")?;
        let recurrent_gate_bias = vb.get((n_heads, block_width), "recurrent_gate_bias")?;
        Ok(Self {
            recurrent_param,
            input_gate_bias,
            input_gate_weight,
            recurrent_gate_bias,
            recurrent_gate_weight,
            block_width,
            n_heads,
            recurrent_states: None,
        })
    }

    // https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py#L303
    pub(crate) fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len, lru_width) = xs.dims3()?;
        let pos = Tensor::arange(pos as u32, (pos + seq_len) as u32, xs.device())?;
        let reset = pos.eq(0u32)?.unsqueeze(1)?.unsqueeze(0)?;
        let reshape_act = xs
            .reshape((b_sz * seq_len, self.n_heads, self.block_width))?
            .permute((1, 0, 2))?
            .contiguous()?;

        let res = baddbmm(
            &self.input_gate_bias.unsqueeze(1)?,
            &reshape_act,
            &self.input_gate_weight,
        )?;
        let input_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
        let input_gate = candle_nn::ops::sigmoid(&input_gate)?;
        let res = baddbmm(
            &self.recurrent_gate_bias.unsqueeze(1)?,
            &reshape_act,
            &self.recurrent_gate_weight,
        )?;
        let recurrent_gate = res.transpose(0, 1)?.reshape((b_sz, seq_len, lru_width))?;
        let recurrent_gate = candle_nn::ops::sigmoid(&recurrent_gate)?;

        let log_recurrent_gate =
            (recurrent_gate * (-8.0))?.broadcast_mul(&softplus(&self.recurrent_param)?)?;
        let recurrent_gate = log_recurrent_gate.exp()?;
        let a_square = (log_recurrent_gate * 2.)?.exp()?;

        // Gate the input.
        let gated_inputs = (xs * input_gate)?;

        let reset = reset.to_dtype(a_square.dtype())?;
        let multiplier =
            reset.broadcast_add(&((1.0 - &reset)?.broadcast_mul(&(1.0 - a_square)?.sqrt()?))?)?;
        let normalized_x = (gated_inputs * multiplier.to_dtype(xs.dtype()))?;

        let (hidden_states, recurrent_states) = rnn_scan(
            &normalized_x,
            &recurrent_gate,
            &reset,
            self.recurrent_states.as_ref(),
        )?;
        self.recurrent_states = Some(recurrent_states);
        Ok(hidden_states)
    }
}

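For orientation, the `Rglru` forward above matches the Griffin-style real-gated LRU. With $\sigma$ the sigmoid, $c = 8$ the constant hard-coded above, $\Lambda$ the `recurrent_param`, and block-diagonal per-head gate weights $W_i$, $W_r$ (shape `(n_heads, block_width, block_width)`):

```latex
\begin{aligned}
i_t &= \sigma(W_i x_t + b_i), \qquad r_t = \sigma(W_r x_t + b_r) \\
a_t &= \exp\bigl(-c \, r_t \odot \operatorname{softplus}(\Lambda)\bigr) \\
h_t &= a_t \odot h_{t-1} + \sqrt{1 - a_t^2} \odot (i_t \odot x_t)
\end{aligned}
```

At a sequence reset (absolute position 0) the code zeroes $a_t$ and replaces the $\sqrt{1 - a_t^2}$ factor by 1, which is exactly the `reset`/`multiplier` handling; the time recursion itself is carried out by `rnn_scan` below.
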
fn rnn_scan(
    hidden_states: &Tensor,
    recurrent_gate: &Tensor,
    reset: &Tensor,
    recurrent_states: Option<&Tensor>,
) -> Result<(Tensor, Tensor)> {
    let acc_dtype = DType::F32;
    let dev = hidden_states.device();
    let in_dtype = hidden_states.dtype();
    let inv_reset = (1.0 - reset)?.to_dtype(recurrent_gate.dtype())?;
    let recurrent_gate = recurrent_gate.broadcast_mul(&inv_reset)?;
    let (c, r) = if hidden_states.dim(1)? == 1 {
        match recurrent_states {
            None => {
                let next_state = hidden_states.i((.., 0))?.to_dtype(acc_dtype)?;
                (hidden_states.clone(), next_state)
            }
            Some(recurrent_states) => {
                let contextualized_states =
                    recurrent_gate.to_dtype(acc_dtype)? * recurrent_states.unsqueeze(1)?;
                let contextualized_states =
                    (contextualized_states + hidden_states.to_dtype(acc_dtype)?)?;
                let c = contextualized_states.to_dtype(in_dtype)?;
                let l = contextualized_states.dim(1)?;
                let r = contextualized_states.i((.., l - 1))?;
                (c, r)
            }
        }
    } else {
        let mut recurrent_states = match recurrent_states {
            None => Tensor::zeros(hidden_states.i((.., 0))?.shape(), acc_dtype, dev)?,
            Some(r) => r.clone(),
        };
        let mut contextualized_states = vec![];
        for t in 0..hidden_states.dim(1)? {
            recurrent_states =
                (recurrent_gate.i((.., t))?.to_dtype(acc_dtype)? * recurrent_states)?;
            recurrent_states =
                (recurrent_states + hidden_states.i((.., t))?.to_dtype(acc_dtype)?)?;
            contextualized_states.push(recurrent_states.to_dtype(in_dtype)?)
        }
        let contextualized_states = Tensor::stack(&contextualized_states, 1)?;
        (contextualized_states, recurrent_states)
    };
    Ok((c, r))
}

#[derive(Debug, Clone)]
struct RecurrentBlock {
    linear_y: Linear,
    linear_x: Linear,
    linear_out: Linear,
    conv_1d: candle_nn::Conv1d,
    conv1d_state: Option<Tensor>,
    conv1d_width: usize,
    rg_lru: Rglru,
    act_fn: candle_nn::Activation,
}

impl RecurrentBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let lru_width = cfg.lru_width.unwrap_or(h);
        let linear_y = linear(h, lru_width, true, vb.pp("linear_y"))?;
        let linear_x = linear(h, lru_width, true, vb.pp("linear_x"))?;
        let linear_out = linear(lru_width, h, true, vb.pp("linear_out"))?;
        let conv_1d = candle_nn::conv1d(
            lru_width,
            lru_width,
            cfg.conv1d_width,
            candle_nn::Conv1dConfig {
                groups: lru_width,
                padding: cfg.conv1d_width - 1,
                ..Default::default()
            },
            vb.pp("conv_1d"),
        )?;
        let rg_lru = Rglru::new(cfg, vb.pp("rg_lru"))?;
        Ok(Self {
            linear_y,
            linear_x,
            linear_out,
            conv_1d,
            conv1d_state: None,
            conv1d_width: cfg.conv1d_width,
            rg_lru,
            act_fn: cfg.hidden_activation,
        })
    }

    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len, _) = xs.dims3()?;

        let y_branch = xs.apply(&self.linear_y)?.apply(&self.act_fn)?;
        let x_branch = xs.apply(&self.linear_x)?.transpose(1, 2)?;
        let x_branch = if pos == 0 {
            let x_len = x_branch.dim(D::Minus1)?;
            let pad = self.conv1d_width as i64 - x_len as i64 - 1;
            let padded = match pad.cmp(&0) {
                std::cmp::Ordering::Equal => x_branch.clone(),
                std::cmp::Ordering::Less => {
                    let rev_pad = (-pad) as usize;
                    x_branch.narrow(D::Minus1, rev_pad, x_len - rev_pad)?
                }
                std::cmp::Ordering::Greater => {
                    x_branch.pad_with_zeros(D::Minus1, pad as usize, 0)?
                }
            };
            self.conv1d_state = Some(padded);
            x_branch
                .apply(&self.conv_1d)?
                .narrow(D::Minus1, 0, seq_len)?
        } else {
            let conv_state = match self.conv1d_state.as_ref() {
                None => candle::bail!("empty cache despite pos > 0"),
                Some(s) => Tensor::cat(&[s, &x_branch], D::Minus1)?,
            };
            let w = self.conv_1d.weight().i((.., 0, ..))?;
            let x_branch = conv_state.broadcast_mul(&w)?.sum(D::Minus1)?;
            let x_branch = match self.conv_1d.bias() {
                None => x_branch,
                Some(b) => x_branch.broadcast_add(b)?,
            };
            let x_branch = x_branch.unsqueeze(D::Minus1)?;
            self.conv1d_state = Some(conv_state.i((.., .., 1..))?);
            x_branch
        };
        let x_branch = x_branch.transpose(1, 2)?;
        let x_branch = self.rg_lru.forward(&x_branch, pos)?;
        (x_branch * y_branch)?.apply(&self.linear_out)
    }
}

#[derive(Debug, Clone)]
struct SdpaAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    kv_cache: Option<(Tensor, Tensor)>,
    rotary_emb: Arc<RotaryEmbedding>,
}

impl SdpaAttention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let n_heads = cfg.num_attention_heads;
        let n_kv_heads = cfg.num_key_value_heads;
        let hd = cfg.head_dim;
        let q_proj = linear(h, n_heads * hd, cfg.attention_bias, vb.pp("q_proj"))?;
        let k_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("k_proj"))?;
        let v_proj = linear(h, n_kv_heads * hd, cfg.attention_bias, vb.pp("v_proj"))?;
        let o_proj = linear(n_heads * hd, h, true, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_heads,
            n_kv_heads,
            head_dim: hd,
            hidden_size: h,
            kv_cache: None,
            rotary_emb,
        })
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.n_heads / self.n_kv_heads;
        crate::utils::repeat_kv(x, n_rep)
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        let (bsz, q_len, _) = xs.dims3()?;

        let query_states = xs.apply(&self.q_proj)?;
        let key_states = xs.apply(&self.k_proj)?;
        let value_states = xs.apply(&self.v_proj)?;

        let query_states = query_states
            .reshape((bsz, q_len, self.n_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((bsz, q_len, self.n_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let query_states = query_states.chunk(2, D::Minus1)?;
        let key_states = key_states.chunk(2, D::Minus1)?;
        let (query_rot, key_rot) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states[0], &key_states[0], pos)?;
        let query_states = Tensor::cat(&[&query_rot, &query_states[1]], D::Minus1)?.contiguous()?;
        let key_states = Tensor::cat(&[&key_rot, &key_states[1]], D::Minus1)?.contiguous()?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let xs = {
            let att = (query_states.matmul(&key_states.t()?)? / (self.head_dim as f64).sqrt())?;
            let att = if q_len == 1 {
                att
            } else {
                match attention_mask {
                    None => att,
                    Some(mask) => att.broadcast_add(mask)?,
                }
            };
            let att = candle_nn::ops::softmax_last_dim(&att)?;
            att.matmul(&value_states.contiguous()?)?
        };

        let xs = xs
            .transpose(1, 2)?
            .reshape((bsz, q_len, self.hidden_size))?;
        self.o_proj.forward(&xs)
    }
}

#[derive(Debug, Clone)]
enum TemporalBlock {
    Recurrent(RecurrentBlock),
    Attention(SdpaAttention),
}

impl TemporalBlock {
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        match self {
            Self::Recurrent(b) => b.forward(xs, pos),
            Self::Attention(b) => b.forward(xs, attention_mask, pos),
        }
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    temporal_pre_norm: RmsNorm,
    channel_pre_norm: RmsNorm,
    temporal_block: TemporalBlock,
    mlp_block: Mlp,
}

impl DecoderLayer {
    fn new(
        block_idx: usize,
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let h = cfg.hidden_size;
        let temporal_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("temporal_pre_norm"))?;
        let channel_pre_norm = RmsNorm::new(h, cfg.rms_norm_eps, vb.pp("channel_pre_norm"))?;
        let temporal_block = match cfg.block_types[block_idx % cfg.block_types.len()] {
            TemporalBlockType::Recurrent => {
                let block = RecurrentBlock::new(cfg, vb.pp("temporal_block"))?;
                TemporalBlock::Recurrent(block)
            }
            TemporalBlockType::Attention => {
                let block = SdpaAttention::new(rotary_emb, cfg, vb.pp("temporal_block"))?;
                TemporalBlock::Attention(block)
            }
        };
        let mlp_block = Mlp::new(cfg, vb.pp("mlp_block"))?;
        Ok(Self {
            temporal_pre_norm,
            channel_pre_norm,
            temporal_block,
            mlp_block,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        pos: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.temporal_pre_norm)?;
        let xs = self.temporal_block.forward(&xs, attention_mask, pos)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.channel_pre_norm)?.apply(&self.mlp_block)?;
        xs + residual
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    final_norm: RmsNorm,
    lm_head: Linear,
    hidden_size: usize,
    logits_soft_cap: f64,
    dtype: DType,
    device: Device,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
        let vb_b = vb.pp("layers");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(idx, rotary_emb.clone(), cfg, vb_b.pp(idx))?;
            layers.push(layer)
        }
        let final_norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("final_norm"))?;
        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            final_norm,
            lm_head,
            hidden_size: cfg.hidden_size,
            logits_soft_cap: cfg.logits_soft_cap,
            dtype: vb.dtype(),
            device: vb.device().clone(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, xs: &Tensor, pos: usize) -> Result<Tensor> {
        let (b_size, seq_len) = xs.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, pos)?;
            Some(mask)
        };
        let xs = xs.apply(&self.embed_tokens)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), pos)?;
        }
        let logits = xs
            .narrow(1, seq_len - 1, 1)?
            .apply(&self.final_norm)?
            .apply(&self.lm_head)?;
        let logits = ((logits / self.logits_soft_cap)?.tanh()? * self.logits_soft_cap)?;
        Ok(logits)
    }
}

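Not shown in this diff: instantiating the model follows the usual candle pattern. A minimal loading sketch, with file names as placeholder assumptions (an example binary would normally resolve them via hf-hub):

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::recurrent_gemma::{Config, Model};

fn load_model() -> anyhow::Result<Model> {
    let device = Device::cuda_if_available(0)?;
    // Placeholder paths for the config and weight files.
    let config: Config = serde_json::from_slice(&std::fs::read("config.json")?)?;
    let weights = vec!["model.safetensors".to_string()];
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weights, DType::F32, &device)? };
    Ok(Model::new(&config, vb)?)
}
```
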
@@ -217,18 +217,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@@ -275,8 +263,9 @@ impl Attention {
            self.kv_cache = Some((key_states.clone(), value_states.clone()));
        }

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)

@@ -139,18 +139,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@@ -187,8 +175,8 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let scale = 1f64 / f64::sqrt(self.head_dim as f64);
        let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

@@ -183,7 +183,7 @@ impl Module for T5LayerNorm {
        let xs_f32 = xs.to_dtype(DType::F32)?;
        // variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
        let xs = xs.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?;
        let xs = xs.to_dtype(dtype)?;
        let xs = xs.broadcast_mul(&self.weight)?;
        Ok(xs)
@@ -709,8 +709,10 @@ impl T5EncoderModel {
    pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let shared_vb = if vb.contains_tensor("shared.weight") {
            vb.pp("shared")
        } else {
        } else if vb.contains_tensor("decoder.embed_tokens") {
            vb.pp("decoder").pp("embed_tokens")
        } else {
            vb.pp("encoder").pp("embed_tokens")
        };
        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
        let shared = Arc::new(shared);

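The T5LayerNorm hunk above fixes a dtype slip: the old code divided the original (possibly f16/bf16) `xs` by the f32 variance, while the fix keeps the whole normalization on the f32 copy and casts back once at the end. Restated as a standalone helper (illustrative only):

```rust
use candle::{DType, Result, Tensor, D};

// RMS-style norm computed entirely in f32, cast back to the input dtype last.
fn t5_layer_norm(xs: &Tensor, weight: &Tensor, variance_epsilon: f64) -> Result<Tensor> {
    let dtype = xs.dtype();
    let xs_f32 = xs.to_dtype(DType::F32)?;
    let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?;
    let xs = xs_f32.broadcast_div(&(variance + variance_epsilon)?.sqrt()?)?;
    xs.to_dtype(dtype)?.broadcast_mul(weight)
}
```
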
@@ -175,18 +175,6 @@ impl Attention {
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
@@ -223,8 +211,8 @@ impl Attention {
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?;
        let value_states = self.repeat_kv(value_states)?;
        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);

@@ -63,7 +63,7 @@ impl VarBuilder {
        let path = self.path(name);
        match self.data.get(&path) {
            None => {
                candle::bail!("cannot find tensor {name}")
                candle::bail!("cannot find tensor {path}")
            }
            Some(qtensor) => {
                let shape = s.into();

@@ -2,7 +2,7 @@ use candle::{Result, Tensor};

pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> Result<Tensor> {
    let device = logits.device();
    let mut logits = logits.to_vec1::<f32>()?;
    let mut logits = logits.to_dtype(candle::DType::F32)?.to_vec1::<f32>()?;
    let mut already_seen = std::collections::HashSet::new();
    for token_id in context {
        if already_seen.contains(token_id) {
@@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
    let logits_len = logits.len();
    Tensor::from_vec(logits, logits_len, device)
}

/// Repeats a key or value tensor for grouped query attention
/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        Ok(xs)
    } else {
        let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
        // Using cat is faster than a broadcast as it avoids going through a potentially
        // strided copy.
        // https://github.com/huggingface/candle/pull/2043
        Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
    }
}
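This new shared helper is what the per-model `repeat_kv` methods removed earlier in this diff now delegate to. The cat-based formulation is numerically identical to the old unsqueeze/expand/reshape one; a quick standalone check (illustrative only, not part of the diff):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, n_kv, seq, hd, n_rep) = (2usize, 4, 3, 8, 2);
    let xs = Tensor::randn(0f32, 1.0, (b, n_kv, seq, hd), &dev)?;
    // New formulation: concatenate copies along the seq axis, then fold into heads.
    let via_cat = Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b, n_kv * n_rep, seq, hd))?;
    // Old formulation: broadcast a fresh axis and flatten it into the head axis.
    let via_expand = xs
        .unsqueeze(2)?
        .expand((b, n_kv, n_rep, seq, hd))?
        .reshape((b, n_kv * n_rep, seq, hd))?;
    let diff = (via_cat - via_expand)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert_eq!(diff, 0.0);
    Ok(())
}
```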