mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
3 Commits
improve-sa
...
phi2-gguf
Author | SHA1 | Date | |
---|---|---|---|
3754b834f4 | |||
d79041d94d | |||
af11b2d461 |
22
Cargo.toml
22
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.package]
|
||||
version = "0.5.1"
|
||||
version = "0.5.0"
|
||||
edition = "2021"
|
||||
description = "Minimalist ML framework."
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
|
||||
accelerate-src = { version = "0.3.2" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = "1.4.3"
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.5.1" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.5.1" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.1" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.5.1" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.1" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.5.1" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
|
||||
candle = { path = "./candle-core", package = "candle-core", version = "0.5.0" }
|
||||
candle-datasets = { path = "./candle-datasets", version = "0.5.0" }
|
||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.5.0" }
|
||||
candle-kernels = { path = "./candle-kernels", version = "0.5.0" }
|
||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.5.0" }
|
||||
candle-nn = { path = "./candle-nn", version = "0.5.0" }
|
||||
candle-onnx = { path = "./candle-onnx", version = "0.5.0" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.5.0" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
@ -65,13 +65,13 @@ serde = { version = "1.0.171", features = ["derive"] }
|
||||
serde_plain = "1.0.2"
|
||||
serde_json = "1.0.99"
|
||||
thiserror = "1"
|
||||
tokenizers = { version = "0.19.1", default-features = false }
|
||||
tokenizers = { version = "0.15.0", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "1.1.1", default-features = false }
|
||||
zip = { version = "0.6.6", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
||||
[profile.release-with-debug]
|
||||
|
15
README.md
15
README.md
@ -60,14 +60,13 @@ These online demos run entirely in your browser:
|
||||
|
||||
We also provide a some command line based examples using state of the art models:
|
||||
|
||||
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
|
||||
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
|
||||
the SOLAR-10.7B variant.
|
||||
- [Falcon](./candle-examples/examples/falcon/): general LLM.
|
||||
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
|
||||
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
|
||||
Griffin based models from Google that mix attention with a RNN like state.
|
||||
- [Phi-1, Phi-1.5, Phi-2, and Phi-3](./candle-examples/examples/phi/): 1.3b,
|
||||
2.7b, and 3.8b general LLMs with performance on par with 7b models.
|
||||
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
|
||||
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
|
||||
pre-trained on 1T tokens of English and code datasets. Also supports
|
||||
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
|
||||
@ -112,7 +111,7 @@ We also provide a some command line based examples using state of the art models
|
||||
|
||||
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
|
||||
|
||||
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmentation model.
|
||||
- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
|
||||
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
|
||||
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
|
||||
model using residual vector quantization.
|
||||
@ -201,10 +200,10 @@ If you have an addition to this list, please submit a pull request.
|
||||
- WASM support, run your models in a browser.
|
||||
- Included models.
|
||||
- Language Models.
|
||||
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
|
||||
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
|
||||
- Falcon.
|
||||
- StarCoder, StarCoder2.
|
||||
- Phi 1, 1.5, 2, and 3.
|
||||
- Phi 1, 1.5, and 2.
|
||||
- Mamba, Minimal Mamba
|
||||
- Gemma 2b and 7b.
|
||||
- Mistral 7b v0.1.
|
||||
@ -376,9 +375,9 @@ git submodule update --init
|
||||
/usr/include/c++/11/bits/std_function.h:530:146: error: parameter packs not expanded with ‘...’:
|
||||
```
|
||||
|
||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the NVCC_CCBIN environment variable.
|
||||
This is a bug in gcc-11 triggered by the Cuda compiler. To fix this, install a different, supported gcc version - for example gcc-10, and specify the path to the compiler in the CANDLE_NVCC_CCBIN environment variable.
|
||||
```
|
||||
env NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||
env CANDLE_NVCC_CCBIN=/usr/lib/gcc/x86_64-linux-gnu/10 cargo ...
|
||||
```
|
||||
|
||||
#### Linking error on windows when running rustdoc or mdbook tests
|
||||
|
@ -81,7 +81,7 @@ let mut tp_shape = view.shape().to_vec();
|
||||
let size = tp_shape[0];
|
||||
|
||||
if size % world_size != 0 {
|
||||
panic!("The dimension is not divisible by `world_size`");
|
||||
panic!("The dimension is not divisble by `world_size`");
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
|
@ -8,5 +8,4 @@ criterion_main!(
|
||||
benchmarks::where_cond::benches,
|
||||
benchmarks::conv_transpose2d::benches,
|
||||
benchmarks::qmatmul::benches,
|
||||
benchmarks::unary::benches
|
||||
);
|
||||
|
@ -3,7 +3,6 @@ pub(crate) mod conv_transpose2d;
|
||||
pub(crate) mod matmul;
|
||||
pub(crate) mod qmatmul;
|
||||
pub(crate) mod random;
|
||||
pub(crate) mod unary;
|
||||
pub(crate) mod where_cond;
|
||||
|
||||
use candle_core::{Device, Result};
|
||||
|
@ -1,49 +0,0 @@
|
||||
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
|
||||
use candle_core::{DType, Device, Tensor};
|
||||
use criterion::{black_box, criterion_group, Criterion, Throughput};
|
||||
use std::time::Instant;
|
||||
|
||||
fn run(a: &Tensor) {
|
||||
a.sqrt().unwrap();
|
||||
}
|
||||
|
||||
fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
|
||||
let b = 1;
|
||||
let m = 1024;
|
||||
let k = 1024;
|
||||
|
||||
let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
|
||||
.unwrap()
|
||||
.to_dtype(dtype)
|
||||
.unwrap()
|
||||
.reshape((b, m, k))
|
||||
.unwrap();
|
||||
|
||||
let flops = b * m * k * dtype.size_in_bytes();
|
||||
|
||||
let mut group = c.benchmark_group(device.bench_name(name));
|
||||
group.throughput(Throughput::Bytes(flops as u64));
|
||||
group.bench_function("iter", move |b| {
|
||||
b.iter_custom(|iters| {
|
||||
let start = Instant::now();
|
||||
for _i in 0..iters {
|
||||
run(black_box(&tensor));
|
||||
}
|
||||
device.sync().unwrap();
|
||||
start.elapsed()
|
||||
})
|
||||
});
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
let handler = BenchDeviceHandler::new().unwrap();
|
||||
for device in handler.devices {
|
||||
for dtype in [DType::F32, DType::BF16, DType::F16] {
|
||||
let name = format!("sqrt_{:?}", dtype);
|
||||
run_unary_benchmark(c, &device, dtype, &name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_benchmark);
|
@ -133,8 +133,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||
/// after this call.
|
||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage>;
|
||||
|
||||
fn storage_from_cpu_storage_owned(&self, _: CpuStorage) -> Result<Self::Storage>;
|
||||
|
@ -624,7 +624,7 @@ impl Tensor {
|
||||
Op::Unary(arg, UnaryOp::Silu) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
|
||||
let sigmoid_arg = (arg.neg()?.exp()? + 1.)?.recip()?;
|
||||
let sigmoid_arg = (*node / arg)?;
|
||||
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
|
||||
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
pub mod erf;
|
||||
pub mod kernels;
|
||||
|
||||
#[allow(unused)]
|
||||
trait Cpu<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
@ -19,7 +18,6 @@ trait Cpu<const ARR: usize> {
|
||||
unsafe fn vec_store(mem_addr: *mut f32, a: Self::Unit);
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
trait CpuF16<const ARR: usize> {
|
||||
type Unit;
|
||||
type Array;
|
||||
|
@ -26,17 +26,6 @@ pub enum CpuStorage {
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CpuStorageRef<'a> {
|
||||
U8(&'a [u8]),
|
||||
U32(&'a [u32]),
|
||||
I64(&'a [i64]),
|
||||
BF16(&'a [bf16]),
|
||||
F16(&'a [f16]),
|
||||
F32(&'a [f32]),
|
||||
F64(&'a [f64]),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CpuDevice;
|
||||
|
||||
@ -2456,10 +2445,6 @@ impl BackendDevice for CpuDevice {
|
||||
true
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
Ok(T::to_cpu_storage(s))
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, s: &CpuStorage) -> Result<Self::Storage> {
|
||||
Ok(s.clone())
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::backend::BackendDevice;
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
pub use candle_kernels as kernels;
|
||||
pub use cudarc;
|
||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
||||
@ -334,43 +334,6 @@ impl BackendDevice for CudaDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let slice = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::U8(data)
|
||||
}
|
||||
CpuStorageRef::U32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::U32(data)
|
||||
}
|
||||
CpuStorageRef::I64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::I64(data)
|
||||
}
|
||||
CpuStorageRef::BF16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::BF16(data)
|
||||
}
|
||||
CpuStorageRef::F16(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::F16(data)
|
||||
}
|
||||
CpuStorageRef::F32(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::F32(data)
|
||||
}
|
||||
CpuStorageRef::F64(storage) => {
|
||||
let data = self.htod_sync_copy(storage).w()?;
|
||||
CudaStorageSlice::F64(data)
|
||||
}
|
||||
};
|
||||
Ok(CudaStorage {
|
||||
slice,
|
||||
device: self.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||
let slice = match storage {
|
||||
CpuStorage::U8(storage) => {
|
||||
|
@ -18,7 +18,7 @@ pub use device::{CudaDevice, DeviceId};
|
||||
pub use error::{CudaError, WrapErr};
|
||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
||||
|
||||
pub enum SlicePtrOrNull<T> {
|
||||
enum SlicePtrOrNull<T> {
|
||||
Ptr(CudaSlice<T>),
|
||||
Null,
|
||||
}
|
||||
@ -33,7 +33,7 @@ unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
|
||||
}
|
||||
|
||||
impl SlicePtrOrNull<usize> {
|
||||
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||
fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||
let ds = if l.is_contiguous() {
|
||||
SlicePtrOrNull::Null
|
||||
} else {
|
||||
@ -250,6 +250,44 @@ impl Map1 for Powf {
|
||||
}
|
||||
}
|
||||
|
||||
struct Sum<'a>(&'a [usize]);
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let src_dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let mut dst_el = el;
|
||||
for &sum_dim in self.0.iter() {
|
||||
dst_el /= src_dims[sum_dim];
|
||||
}
|
||||
let mut sum_dims = self.0.to_vec();
|
||||
// Sort the sum_dims as they have to be processed from left to right when converting the
|
||||
// indexes.
|
||||
sum_dims.sort();
|
||||
let sum_dims_l: Vec<usize> = sum_dims.iter().map(|&d| src_dims[d]).collect();
|
||||
let sum_dims_s: Vec<usize> = sum_dims
|
||||
.iter()
|
||||
.map(|&d| src_dims[d + 1..].iter().product::<usize>())
|
||||
.collect();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev
|
||||
.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())
|
||||
.w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("sum"), kernels::REDUCE)?;
|
||||
let out = dev.alloc_zeros::<T>(dst_el).w()?;
|
||||
let params = (el, src_dims.len(), sum_dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct FastReduce<'a>(&'a [usize], ReduceOp);
|
||||
impl<'a> Map1Any for FastReduce<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
@ -1597,8 +1635,12 @@ impl BackendStorage for CudaStorage {
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||
.w()?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}
|
||||
.w()?;
|
||||
CudaStorageSlice::BF16(out)
|
||||
}
|
||||
(CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => {
|
||||
@ -1606,8 +1648,12 @@ impl BackendStorage for CudaStorage {
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||
.w()?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}
|
||||
.w()?;
|
||||
CudaStorageSlice::F16(out)
|
||||
}
|
||||
(CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => {
|
||||
@ -1810,146 +1856,3 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Default for the reduced precision setting is false, similar to pytorch.
|
||||
// https://github.com/pytorch/pytorch/issues/123157
|
||||
static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||
std::sync::atomic::AtomicBool::new(false);
|
||||
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||
std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
pub fn gemm_reduced_precision_f16() -> bool {
|
||||
MM_F16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_f16(b: bool) {
|
||||
MM_F16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn gemm_reduced_precision_bf16() -> bool {
|
||||
MM_BF16_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_bf16(b: bool) {
|
||||
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
unsafe fn gemm_strided_batched_f16(
|
||||
cublas: &cudarc::cublas::CudaBlas,
|
||||
cfg: StridedBatchedConfig<f16>,
|
||||
a: &cudarc::driver::CudaView<f16>,
|
||||
b: &cudarc::driver::CudaView<f16>,
|
||||
c: &mut CudaSlice<f16>,
|
||||
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
|
||||
use cudarc::cublas::sys;
|
||||
use cudarc::driver::DevicePtrMut;
|
||||
|
||||
let alpha = cfg.gemm.alpha;
|
||||
let beta = cfg.gemm.beta;
|
||||
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||
let (compute_type, alpha, beta) = if gemm_reduced_precision_f16() {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||
(&alpha) as *const f16 as *const _,
|
||||
(&beta) as *const f16 as *const _,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
(&alpha_f32) as *const f32 as *const _,
|
||||
(&beta_f32) as *const f32 as *const _,
|
||||
)
|
||||
};
|
||||
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
cfg.gemm.transb,
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
alpha,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.lda,
|
||||
cfg.stride_a,
|
||||
*b.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
beta,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_16F,
|
||||
cfg.gemm.ldc,
|
||||
cfg.stride_c,
|
||||
cfg.batch_size,
|
||||
compute_type,
|
||||
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe fn gemm_strided_batched_bf16(
|
||||
cublas: &cudarc::cublas::CudaBlas,
|
||||
cfg: StridedBatchedConfig<bf16>,
|
||||
a: &cudarc::driver::CudaView<bf16>,
|
||||
b: &cudarc::driver::CudaView<bf16>,
|
||||
c: &mut CudaSlice<bf16>,
|
||||
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
|
||||
use cudarc::cublas::sys;
|
||||
use cudarc::driver::DevicePtrMut;
|
||||
|
||||
let alpha_f32: f32 = cfg.gemm.alpha.to_f32();
|
||||
let beta_f32: f32 = cfg.gemm.beta.to_f32();
|
||||
let alpha = f16::from_f32(alpha_f32);
|
||||
let beta = f16::from_f32(beta_f32);
|
||||
// The type for alpha and beta depends on the computeType.
|
||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex
|
||||
let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_16F,
|
||||
(&alpha) as *const f16 as *const _,
|
||||
(&beta) as *const f16 as *const _,
|
||||
)
|
||||
} else {
|
||||
(
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
|
||||
(&alpha_f32) as *const f32 as *const _,
|
||||
(&beta_f32) as *const f32 as *const _,
|
||||
)
|
||||
};
|
||||
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
cfg.gemm.transb,
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
alpha,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.lda,
|
||||
cfg.stride_a,
|
||||
*b.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
beta,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_16BF,
|
||||
cfg.gemm.ldc,
|
||||
cfg.stride_c,
|
||||
cfg.batch_size,
|
||||
compute_type,
|
||||
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
||||
)
|
||||
}
|
||||
|
@ -306,20 +306,6 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage_from_slice<D: WithDType>(&self, data: &[D]) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(data.to_cpu_storage())),
|
||||
Device::Cuda(device) => {
|
||||
let storage = device.storage_from_slice(data)?;
|
||||
Ok(Storage::Cuda(storage))
|
||||
}
|
||||
Device::Metal(device) => {
|
||||
let storage = device.storage_from_slice(data)?;
|
||||
Ok(Storage::Metal(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn storage<A: NdArray>(&self, array: A) -> Result<Storage> {
|
||||
match self {
|
||||
Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())),
|
||||
|
@ -1,7 +1,7 @@
|
||||
//! Types for elements that can be stored and manipulated using tensors.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::{CpuStorage, CpuStorageRef, Error, Result};
|
||||
use crate::{CpuStorage, Error, Result};
|
||||
|
||||
/// The different types of elements allowed in tensors.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
@ -100,14 +100,12 @@ pub trait WithDType:
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ std::any::Any
|
||||
+ crate::cpu::kernels::VecOps
|
||||
{
|
||||
const DTYPE: DType;
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
fn to_f64(self) -> f64;
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
@ -131,10 +129,6 @@ macro_rules! with_dtype {
|
||||
$to_f64(self)
|
||||
}
|
||||
|
||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||
CpuStorageRef::$dtype(data)
|
||||
}
|
||||
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||
CpuStorage::$dtype(data)
|
||||
}
|
||||
|
@ -214,10 +214,6 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
@ -238,23 +234,3 @@ impl crate::backend::BackendDevice for CudaDevice {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
pub fn gemm_reduced_precision_f16() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_f16(_: bool) {}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn gemm_reduced_precision_bf16() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
||||
|
@ -226,10 +226,6 @@ impl crate::backend::BackendDevice for MetalDevice {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, _: &[T]) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
@ -219,14 +219,10 @@ impl Error {
|
||||
Self::Wrapped(Box::new(err)).bt()
|
||||
}
|
||||
|
||||
pub fn msg(err: impl std::error::Error) -> Self {
|
||||
pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
|
||||
Self::Msg(err.to_string()).bt()
|
||||
}
|
||||
|
||||
pub fn debug(err: impl std::fmt::Debug) -> Self {
|
||||
Self::Msg(format!("{err:?}")).bt()
|
||||
}
|
||||
|
||||
pub fn bt(self) -> Self {
|
||||
let backtrace = std::backtrace::Backtrace::capture();
|
||||
match backtrace.status() {
|
||||
|
@ -47,7 +47,7 @@ mod custom_op;
|
||||
mod device;
|
||||
pub mod display;
|
||||
mod dtype;
|
||||
pub mod dummy_cuda_backend;
|
||||
mod dummy_cuda_backend;
|
||||
mod dummy_metal_backend;
|
||||
pub mod error;
|
||||
mod indexer;
|
||||
@ -63,7 +63,6 @@ pub mod quantized;
|
||||
pub mod safetensors;
|
||||
pub mod scalar;
|
||||
pub mod shape;
|
||||
mod sort;
|
||||
mod storage;
|
||||
mod strided_index;
|
||||
mod tensor;
|
||||
@ -75,7 +74,7 @@ mod variable;
|
||||
#[cfg(feature = "cudnn")]
|
||||
pub use cuda_backend::cudnn;
|
||||
|
||||
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
||||
pub use cpu_backend::CpuStorage;
|
||||
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||
pub use device::{Device, DeviceLocation, NdArray};
|
||||
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||
@ -89,12 +88,10 @@ pub use tensor::{Tensor, TensorId};
|
||||
pub use variable::Var;
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
pub use cuda_backend as cuda;
|
||||
pub use cuda_backend::{CudaDevice, CudaStorage};
|
||||
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
pub use dummy_cuda_backend as cuda;
|
||||
|
||||
pub use cuda::{CudaDevice, CudaStorage};
|
||||
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
pub use metal_backend::{MetalDevice, MetalError, MetalStorage};
|
||||
|
@ -1,7 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
@ -11,7 +11,7 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
mod device;
|
||||
pub use device::{DeviceId, MetalDevice};
|
||||
|
||||
pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||
BufferOffset {
|
||||
buffer,
|
||||
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
|
||||
@ -444,238 +444,156 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label(B::KERNEL);
|
||||
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||
if layout.is_contiguous() {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
match (el_count % 2, dtype, layout.is_contiguous()) {
|
||||
(0, DType::BF16 | DType::F16, true) => {
|
||||
use candle_metal_kernels::unary::contiguous_tiled;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("uabs", DType::F16) => contiguous_tiled::abs::HALF,
|
||||
("uabs", DType::F32) => contiguous_tiled::abs::FLOAT,
|
||||
("uabs", DType::BF16) => contiguous_tiled::abs::BFLOAT,
|
||||
("uceil", DType::F16) => contiguous_tiled::ceil::HALF,
|
||||
("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT,
|
||||
("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT,
|
||||
("ucos", DType::F16) => contiguous_tiled::cos::HALF,
|
||||
("ucos", DType::F32) => contiguous_tiled::cos::FLOAT,
|
||||
("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT,
|
||||
("uerf", DType::F16) => contiguous_tiled::erf::HALF,
|
||||
("uerf", DType::F32) => contiguous_tiled::erf::FLOAT,
|
||||
("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT,
|
||||
("uexp", DType::F16) => contiguous_tiled::exp::HALF,
|
||||
("uexp", DType::F32) => contiguous_tiled::exp::FLOAT,
|
||||
("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT,
|
||||
("ufloor", DType::F16) => contiguous_tiled::floor::HALF,
|
||||
("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT,
|
||||
("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT,
|
||||
("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF,
|
||||
("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT,
|
||||
("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT,
|
||||
("ugelu", DType::F16) => contiguous_tiled::gelu::HALF,
|
||||
("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT,
|
||||
("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT,
|
||||
("ulog", DType::F16) => contiguous_tiled::log::HALF,
|
||||
("ulog", DType::F32) => contiguous_tiled::log::FLOAT,
|
||||
("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT,
|
||||
("uneg", DType::F16) => contiguous_tiled::neg::HALF,
|
||||
("uneg", DType::F32) => contiguous_tiled::neg::FLOAT,
|
||||
("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT,
|
||||
("urecip", DType::F16) => contiguous_tiled::recip::HALF,
|
||||
("urecip", DType::F32) => contiguous_tiled::recip::FLOAT,
|
||||
("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT,
|
||||
("urelu", DType::F16) => contiguous_tiled::relu::HALF,
|
||||
("urelu", DType::F32) => contiguous_tiled::relu::FLOAT,
|
||||
("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT,
|
||||
("uround", DType::F16) => contiguous_tiled::round::HALF,
|
||||
("uround", DType::F32) => contiguous_tiled::round::FLOAT,
|
||||
("uround", DType::BF16) => contiguous_tiled::round::BFLOAT,
|
||||
("usilu", DType::F16) => contiguous_tiled::silu::HALF,
|
||||
("usilu", DType::F32) => contiguous_tiled::silu::FLOAT,
|
||||
("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT,
|
||||
("usin", DType::F16) => contiguous_tiled::sin::HALF,
|
||||
("usin", DType::F32) => contiguous_tiled::sin::FLOAT,
|
||||
("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT,
|
||||
("usqr", DType::F16) => contiguous_tiled::sqr::HALF,
|
||||
("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT,
|
||||
("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT,
|
||||
("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF,
|
||||
("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT,
|
||||
("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT,
|
||||
("utanh", DType::F16) => contiguous_tiled::tanh::HALF,
|
||||
("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT,
|
||||
("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT,
|
||||
("usign", DType::F16) => contiguous_tiled::sign::HALF,
|
||||
("usign", DType::F32) => contiguous_tiled::sign::FLOAT,
|
||||
("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT,
|
||||
("usign", DType::I64) => contiguous_tiled::sign::I64,
|
||||
(name, dtype) => {
|
||||
crate::bail!(
|
||||
"Metal contiguous_tiled unary {name} {dtype:?} not implemented"
|
||||
)
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous_tiled(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, true) => {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ulog", DType::BF16) => contiguous::log::BFLOAT,
|
||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
|
||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
|
||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("uround", DType::BF16) => contiguous::round::BFLOAT,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usin", DType::BF16) => contiguous::sin::BFLOAT,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
|
||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
|
||||
("usign", DType::F16) => contiguous::sign::HALF,
|
||||
("usign", DType::F32) => contiguous::sign::FLOAT,
|
||||
("usign", DType::BF16) => contiguous::sign::BFLOAT,
|
||||
("usign", DType::I64) => contiguous::sign::I64,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, false) => {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||
("usin", DType::F32) => strided::sin::FLOAT,
|
||||
("usqr", DType::F32) => strided::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => strided::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => strided::neg::FLOAT,
|
||||
("uexp", DType::F32) => strided::exp::FLOAT,
|
||||
("ulog", DType::F32) => strided::log::FLOAT,
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("usilu", DType::F32) => strided::silu::FLOAT,
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("utanh", DType::F32) => strided::tanh::FLOAT,
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ulog", DType::BF16) => contiguous::log::BFLOAT,
|
||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
|
||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
|
||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("uround", DType::BF16) => contiguous::round::BFLOAT,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usin", DType::BF16) => contiguous::sin::BFLOAT,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
|
||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
|
||||
("usign", DType::F16) => contiguous::sign::HALF,
|
||||
("usign", DType::F32) => contiguous::sign::FLOAT,
|
||||
("usign", DType::BF16) => contiguous::sign::BFLOAT,
|
||||
("usign", DType::I64) => contiguous::sign::I64,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
} else {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => strided::cos::FLOAT,
|
||||
("usin", DType::F32) => strided::sin::FLOAT,
|
||||
("usqr", DType::F32) => strided::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => strided::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => strided::neg::FLOAT,
|
||||
("uexp", DType::F32) => strided::exp::FLOAT,
|
||||
("ulog", DType::F32) => strided::log::FLOAT,
|
||||
("ugelu", DType::F32) => strided::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => strided::erf::FLOAT,
|
||||
("usilu", DType::F32) => strided::silu::FLOAT,
|
||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("utanh", DType::F32) => strided::tanh::FLOAT,
|
||||
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||
("uneg", DType::F16) => strided::neg::HALF,
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
("utanh", DType::F16) => strided::tanh::HALF,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
("usqrt", DType::F16) => strided::sqrt::HALF,
|
||||
("uneg", DType::F16) => strided::neg::HALF,
|
||||
("uexp", DType::F16) => strided::exp::HALF,
|
||||
("ulog", DType::F16) => strided::log::HALF,
|
||||
("ugelu", DType::F16) => strided::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => strided::erf::HALF,
|
||||
("usilu", DType::F16) => strided::silu::HALF,
|
||||
("uabs", DType::F16) => strided::abs::HALF,
|
||||
("uceil", DType::F16) => strided::ceil::HALF,
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
("utanh", DType::F16) => strided::tanh::HALF,
|
||||
|
||||
("ucos", DType::BF16) => strided::cos::BFLOAT,
|
||||
("usin", DType::BF16) => strided::sin::BFLOAT,
|
||||
("usqr", DType::BF16) => strided::sqr::BFLOAT,
|
||||
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
|
||||
("uneg", DType::BF16) => strided::neg::BFLOAT,
|
||||
("uexp", DType::BF16) => strided::exp::BFLOAT,
|
||||
("ulog", DType::BF16) => strided::log::BFLOAT,
|
||||
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
|
||||
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
|
||||
("uerf", DType::BF16) => strided::erf::BFLOAT,
|
||||
("usilu", DType::BF16) => strided::silu::BFLOAT,
|
||||
("uabs", DType::BF16) => strided::abs::BFLOAT,
|
||||
("uceil", DType::BF16) => strided::ceil::BFLOAT,
|
||||
("ufloor", DType::BF16) => strided::floor::BFLOAT,
|
||||
("urelu", DType::BF16) => strided::relu::BFLOAT,
|
||||
("uround", DType::BF16) => strided::round::BFLOAT,
|
||||
("utanh", DType::BF16) => strided::tanh::BFLOAT,
|
||||
("ucos", DType::BF16) => strided::cos::BFLOAT,
|
||||
("usin", DType::BF16) => strided::sin::BFLOAT,
|
||||
("usqr", DType::BF16) => strided::sqr::BFLOAT,
|
||||
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
|
||||
("uneg", DType::BF16) => strided::neg::BFLOAT,
|
||||
("uexp", DType::BF16) => strided::exp::BFLOAT,
|
||||
("ulog", DType::BF16) => strided::log::BFLOAT,
|
||||
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
|
||||
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
|
||||
("uerf", DType::BF16) => strided::erf::BFLOAT,
|
||||
("usilu", DType::BF16) => strided::silu::BFLOAT,
|
||||
("uabs", DType::BF16) => strided::abs::BFLOAT,
|
||||
("uceil", DType::BF16) => strided::ceil::BFLOAT,
|
||||
("ufloor", DType::BF16) => strided::floor::BFLOAT,
|
||||
("urelu", DType::BF16) => strided::relu::BFLOAT,
|
||||
("uround", DType::BF16) => strided::round::BFLOAT,
|
||||
("utanh", DType::BF16) => strided::tanh::BFLOAT,
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
let dst = BufferOffset::zero_offset(&buffer);
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
src,
|
||||
layout.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
let dst = BufferOffset::zero_offset(&buffer);
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
&device.device,
|
||||
&command_buffer,
|
||||
&device.kernels,
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
src,
|
||||
layout.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
|
||||
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||
}
|
||||
|
||||
@ -1784,19 +1702,6 @@ impl BackendDevice for MetalDevice {
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match T::cpu_storage_ref(s) {
|
||||
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorageRef::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
};
|
||||
Ok(Self::Storage::new(buffer?, self.clone(), count, T::DTYPE))
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match storage {
|
||||
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
|
@ -330,7 +330,7 @@ impl Tensor {
|
||||
path: P,
|
||||
) -> Result<()> {
|
||||
let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
|
||||
let options: zip::write::FileOptions<()> =
|
||||
let options =
|
||||
zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
|
||||
|
||||
for (name, tensor) in ts.iter() {
|
||||
|
@ -2,7 +2,6 @@ use super::{GgmlDType, QStorage};
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||
use crate::{CudaDevice, CudaStorage, Result};
|
||||
use half::f16;
|
||||
|
||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
||||
|
||||
@ -60,7 +59,7 @@ fn quantize_q8_1(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dequantize_f32(
|
||||
fn dequantize(
|
||||
data: &CudaSlice<u8>,
|
||||
dtype: GgmlDType,
|
||||
elem_count: usize,
|
||||
@ -70,27 +69,27 @@ fn dequantize_f32(
|
||||
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
|
||||
GgmlDType::Q5_0 => (
|
||||
"dequantize_block_q5_0_f32",
|
||||
"dequantize_block_q5_0",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q5_1 => (
|
||||
"dequantize_block_q5_1_f32",
|
||||
"dequantize_block_q5_1",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f32", false, 32, nb),
|
||||
GgmlDType::Q2K => ("dequantize_block_q2_K_f32", true, 64, nb),
|
||||
GgmlDType::Q3K => ("dequantize_block_q3_K_f32", true, 64, nb),
|
||||
GgmlDType::Q4K => ("dequantize_block_q4_K_f32", true, 32, nb),
|
||||
GgmlDType::Q5K => ("dequantize_block_q5_K_f32", true, 64, nb),
|
||||
GgmlDType::Q6K => ("dequantize_block_q6_K_f32", true, 64, nb),
|
||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
|
||||
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
|
||||
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
|
||||
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
|
||||
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
|
||||
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
|
||||
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
@ -117,63 +116,6 @@ fn dequantize_f32(
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
fn dequantize_f16(
|
||||
data: &CudaSlice<u8>,
|
||||
dtype: GgmlDType,
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||
GgmlDType::Q5_0 => (
|
||||
"dequantize_block_q5_0_f16",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q5_1 => (
|
||||
"dequantize_block_q5_1_f16",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
ceil_div(elem_count, 2 * CUDA_DEQUANTIZE_BLOCK_SIZE),
|
||||
),
|
||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0_f16", false, 32, nb),
|
||||
GgmlDType::Q2K => ("dequantize_block_q2_K_f16", true, 64, nb),
|
||||
GgmlDType::Q3K => ("dequantize_block_q3_K_f16", true, 64, nb),
|
||||
GgmlDType::Q4K => ("dequantize_block_q4_K_f16", true, 32, nb),
|
||||
GgmlDType::Q5K => ("dequantize_block_q5_K_f16", true, 64, nb),
|
||||
GgmlDType::Q6K => ("dequantize_block_q6_K_f16", true, 64, nb),
|
||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, 1, 1),
|
||||
block_dim: (block_dim as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
if is_k {
|
||||
let params = (data, &dst);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
} else {
|
||||
let nb32 = match dtype {
|
||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||
_ => elem_count / 32,
|
||||
};
|
||||
let params = (data, &dst, nb32 as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
}
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
fn dequantize_mul_mat_vec(
|
||||
data: &CudaSlice<u8>,
|
||||
y: &CudaView<f32>,
|
||||
@ -236,8 +178,8 @@ fn mul_mat_vec_via_q8_1(
|
||||
if y.len() != ncols * b_size {
|
||||
crate::bail!("unexpected y size {}, ncols {ncols} {nrows}", y.len())
|
||||
}
|
||||
if b_size == 0 || b_size > 8 {
|
||||
crate::bail!("only bsize between 1 and 8 are supported, got {b_size}")
|
||||
if b_size == 0 || b_size > 4 {
|
||||
crate::bail!("only bsize between 1 and 4 are supported, got {b_size}")
|
||||
}
|
||||
// Start by quantizing y
|
||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||
@ -262,16 +204,14 @@ fn mul_mat_vec_via_q8_1(
|
||||
let kernel_name = format!("{kernel_name}{b_size}");
|
||||
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||
let (nblocks, nwarps) = match b_size {
|
||||
1 => (nrows as u32, 4),
|
||||
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
||||
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||
let nblocks = if b_size == 1 {
|
||||
nrows as u32
|
||||
} else {
|
||||
(nrows as u32 + 1) / 2
|
||||
};
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (nblocks, 1, 1),
|
||||
block_dim: (WARP_SIZE as u32, nwarps, 1),
|
||||
block_dim: (WARP_SIZE as u32, 4, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
@ -399,7 +339,7 @@ impl QCudaStorage {
|
||||
| GgmlDType::Q8K
|
||||
);
|
||||
if fast_kernel {
|
||||
return dequantize_f32(&self.data, self.dtype, elem_count, self.device());
|
||||
return dequantize(&self.data, self.dtype, elem_count, self.device());
|
||||
}
|
||||
// Run the dequantization on cpu.
|
||||
|
||||
@ -427,10 +367,6 @@ impl QCudaStorage {
|
||||
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||
dequantize_f16(&self.data, self.dtype, elem_count, self.device())
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||
// Run the quantization on cpu.
|
||||
let src = match &src.slice {
|
||||
@ -462,7 +398,7 @@ impl QCudaStorage {
|
||||
let max_bm = if FORCE_DMMV.load(std::sync::atomic::Ordering::Relaxed) {
|
||||
1
|
||||
} else {
|
||||
8
|
||||
4
|
||||
};
|
||||
let use_vec_kernel = match layout.shape().dims() {
|
||||
[b, m, _k] => b * m <= max_bm,
|
||||
|
@ -24,10 +24,6 @@ impl QCudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
@ -135,6 +135,7 @@ pub enum ValueType {
|
||||
// The value is a UTF-8 non-null-terminated string, with length prepended.
|
||||
String,
|
||||
// The value is an array of other values, with the length and type prepended.
|
||||
///
|
||||
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
|
||||
Array,
|
||||
}
|
||||
|
@ -152,9 +152,9 @@ impl QMetalStorage {
|
||||
// We always use a single batch dimension and stack all the tensors in the batch on the
|
||||
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
|
||||
// properly.
|
||||
let m = match dst_shape.len() {
|
||||
3 => dst_shape[0] * dst_shape[1],
|
||||
2 => dst_shape[0],
|
||||
let (b, m) = match dst_shape.len() {
|
||||
3 => (1, dst_shape[0] * dst_shape[1]),
|
||||
2 => (1, dst_shape[0]),
|
||||
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
|
||||
};
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
@ -166,23 +166,18 @@ impl QMetalStorage {
|
||||
let device = storage.device().clone();
|
||||
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
// In some cases it would be better to use the mm variant, though it has its drawbacks
|
||||
// around memory alignemnt.
|
||||
for batch_id in 0..m {
|
||||
candle_metal_kernels::call_quantized_matmul_mv_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(1, 1, n, k),
|
||||
storage.buffer(),
|
||||
(layout.start_offset() + batch_id * k) * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
batch_id * n * DType::F32.size_in_bytes(),
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
candle_metal_kernels::call_quantized_matmul_t(
|
||||
device.device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
self.dtype.into(),
|
||||
(b, m, n, k),
|
||||
storage.buffer(),
|
||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
&self.buffer,
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use crate::{CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
|
||||
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
|
||||
use k_quants::*;
|
||||
use std::borrow::Cow;
|
||||
|
||||
@ -360,24 +360,9 @@ impl QTensor {
|
||||
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
|
||||
let storage = self.storage.dequantize(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
|
||||
// In the CUDA case, we have a specialized kernel as this can be useful for volta
|
||||
// architectures. https://github.com/huggingface/candle/issues/2136
|
||||
match &self.storage {
|
||||
QStorage::Cuda(s) => {
|
||||
let s = s.dequantize_f16(self.shape.elem_count())?;
|
||||
let none = crate::op::BackpropOp::none();
|
||||
crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
|
||||
.to_device(device)
|
||||
}
|
||||
_ => {
|
||||
let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
let is_variable = false;
|
||||
crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
|
||||
.to_device(device)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
@ -393,7 +378,6 @@ impl QTensor {
|
||||
pub enum QMatMul {
|
||||
QTensor(std::sync::Arc<QTensor>),
|
||||
Tensor(Tensor),
|
||||
TensorF16(Tensor),
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
@ -407,17 +391,6 @@ thread_local! {
|
||||
}
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static DEQUANTIZE_ALL_F16: bool = {
|
||||
match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
|
||||
Ok(s) => {
|
||||
!s.is_empty() && s != "0"
|
||||
},
|
||||
Err(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
|
||||
let dequantize = match qtensor.dtype() {
|
||||
@ -427,9 +400,6 @@ impl QMatMul {
|
||||
let t = if dequantize {
|
||||
let tensor = qtensor.dequantize(&qtensor.device())?;
|
||||
Self::Tensor(tensor)
|
||||
} else if DEQUANTIZE_ALL_F16.with(|b| *b) {
|
||||
let tensor = qtensor.dequantize_f16(&qtensor.device())?;
|
||||
Self::TensorF16(tensor)
|
||||
} else {
|
||||
Self::QTensor(qtensor)
|
||||
};
|
||||
@ -439,25 +409,6 @@ impl QMatMul {
|
||||
pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||
Self::from_arc(std::sync::Arc::new(qtensor))
|
||||
}
|
||||
|
||||
pub fn dequantize_f16(&self) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::QTensor(t) => t.dequantize_f16(&t.device()),
|
||||
Self::Tensor(t) => t.to_dtype(DType::F16),
|
||||
Self::TensorF16(t) => Ok(t.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let w = self.dequantize_f16()?;
|
||||
let in_dtype = xs.dtype();
|
||||
let w = match *xs.dims() {
|
||||
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||
_ => w.t()?,
|
||||
};
|
||||
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for QTensor {
|
||||
@ -535,15 +486,6 @@ impl crate::Module for QMatMul {
|
||||
};
|
||||
xs.matmul(&w)
|
||||
}
|
||||
Self::TensorF16(w) => {
|
||||
let in_dtype = xs.dtype();
|
||||
let w = match *xs.dims() {
|
||||
[b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
|
||||
[bsize, _, _] => w.broadcast_left(bsize)?.t()?,
|
||||
_ => w.t()?,
|
||||
};
|
||||
xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,239 +0,0 @@
|
||||
use crate::{Result, Tensor};
|
||||
use rayon::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct ArgSort {
|
||||
asc: bool,
|
||||
last_dim: usize,
|
||||
}
|
||||
|
||||
impl ArgSort {
|
||||
fn asort<T: crate::WithDType>(&self, vs: &[T], layout: &crate::Layout) -> Vec<u32> {
|
||||
#[allow(clippy::uninit_vec)]
|
||||
// Safety: indexes are set later in the parallelized section.
|
||||
let mut sort_indexes = unsafe {
|
||||
let el_count = layout.shape().elem_count();
|
||||
let mut v = Vec::with_capacity(el_count);
|
||||
v.set_len(el_count);
|
||||
v
|
||||
};
|
||||
if self.asc {
|
||||
sort_indexes
|
||||
.par_chunks_exact_mut(self.last_dim)
|
||||
.zip(vs.par_chunks_exact(self.last_dim))
|
||||
.for_each(|(indexes, vs)| {
|
||||
indexes
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| *v = i as u32);
|
||||
indexes.sort_by(|&i, &j| {
|
||||
vs[i as usize]
|
||||
.partial_cmp(&vs[j as usize])
|
||||
.unwrap_or(std::cmp::Ordering::Greater)
|
||||
})
|
||||
});
|
||||
} else {
|
||||
sort_indexes
|
||||
.par_chunks_exact_mut(self.last_dim)
|
||||
.zip(vs.par_chunks_exact(self.last_dim))
|
||||
.for_each(|(indexes, vs)| {
|
||||
indexes
|
||||
.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, v)| *v = i as u32);
|
||||
indexes.sort_by(|&j, &i| {
|
||||
vs[i as usize]
|
||||
.partial_cmp(&vs[j as usize])
|
||||
.unwrap_or(std::cmp::Ordering::Greater)
|
||||
})
|
||||
});
|
||||
}
|
||||
sort_indexes
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::CustomOp1 for ArgSort {
|
||||
fn name(&self) -> &'static str {
|
||||
"argsort"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
storage: &crate::CpuStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CpuStorage, crate::Shape)> {
|
||||
let sort_indexes = match storage {
|
||||
crate::CpuStorage::U8(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::U32(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::I64(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::BF16(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::F16(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::F32(vs) => self.asort(vs, layout),
|
||||
crate::CpuStorage::F64(vs) => self.asort(vs, layout),
|
||||
};
|
||||
let sort_indexes = crate::CpuStorage::U32(sort_indexes);
|
||||
Ok((sort_indexes, layout.shape().into()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
storage: &crate::CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CudaStorage, crate::Shape)> {
|
||||
use crate::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
};
|
||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, Map1Any, WrapErr};
|
||||
use crate::{CudaDevice, WithDType};
|
||||
|
||||
impl Map1Any for ArgSort {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &crate::Layout,
|
||||
_wrap: W,
|
||||
) -> Result<S> {
|
||||
let slice = match layout.contiguous_offsets() {
|
||||
None => crate::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => src.slice(o1..o2),
|
||||
};
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||
let func = if self.asc {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
||||
} else {
|
||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
||||
};
|
||||
let ncols = self.last_dim;
|
||||
let nrows = elem_count / ncols;
|
||||
let ncols_pad = next_power_of_2(ncols);
|
||||
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (1, nrows as u32, 1),
|
||||
block_dim: (ncols_pad as u32, 1, 1),
|
||||
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||
};
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(S::U32(dst))
|
||||
}
|
||||
}
|
||||
|
||||
use crate::backend::BackendStorage;
|
||||
let dev = storage.device();
|
||||
let slice = self.map(&storage.slice, dev, layout)?;
|
||||
let dst = crate::cuda_backend::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, crate::Shape)> {
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::DType;
|
||||
|
||||
let name = {
|
||||
if self.asc {
|
||||
match storage.dtype() {
|
||||
DType::BF16 => "asort_asc_bf16",
|
||||
DType::F16 => "asort_asc_f16",
|
||||
DType::F32 => "asort_asc_f32",
|
||||
DType::F64 => "asort_asc_f64",
|
||||
DType::U8 => "asort_asc_u8",
|
||||
DType::U32 => "asort_asc_u32",
|
||||
DType::I64 => "asort_asc_i64",
|
||||
}
|
||||
} else {
|
||||
match storage.dtype() {
|
||||
DType::BF16 => "asort_desc_bf16",
|
||||
DType::F16 => "asort_desc_f16",
|
||||
DType::F32 => "asort_desc_f32",
|
||||
DType::F64 => "asort_desc_f64",
|
||||
DType::U8 => "asort_desc_u8",
|
||||
DType::U32 => "asort_desc_u32",
|
||||
DType::I64 => "asort_desc_i64",
|
||||
}
|
||||
}
|
||||
};
|
||||
let device = storage.device();
|
||||
let kernels = device.kernels();
|
||||
let command_buffer = device.command_buffer()?;
|
||||
let el = layout.shape().elem_count();
|
||||
let ncols = self.last_dim;
|
||||
let nrows = el / ncols;
|
||||
let src = crate::metal_backend::buffer_o(storage.buffer(), layout, storage.dtype());
|
||||
let dst = device.new_buffer(el, DType::U32, "asort")?;
|
||||
let mut ncols_pad = 1;
|
||||
while ncols_pad < ncols {
|
||||
ncols_pad *= 2;
|
||||
}
|
||||
candle_metal_kernels::call_arg_sort(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
kernels,
|
||||
name,
|
||||
nrows,
|
||||
ncols,
|
||||
ncols_pad,
|
||||
src,
|
||||
&dst,
|
||||
)
|
||||
.map_err(crate::Error::wrap)?;
|
||||
let dst = crate::MetalStorage::new(dst, device.clone(), el, DType::U32);
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the smallest power of two that is greater than or equal to `x`.
///
/// `0` and `1` both map to `1`.
// Only referenced from feature-gated backend code, so it can be unused in some builds.
#[allow(unused)]
fn next_power_of_2(x: usize) -> usize {
    // The standard-library routine replaces the doubling loop, which could spin forever once
    // the running value wrapped around to zero in a build without overflow checks.
    x.next_power_of_two()
}
|
||||
|
||||
impl Tensor {
|
||||
/// Returns the indices that sort the tensor along the last dimension.
|
||||
///
|
||||
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||
/// comes to ties.
|
||||
pub fn arg_sort_last_dim(&self, asc: bool) -> Result<Tensor> {
|
||||
if !self.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous {
|
||||
op: "arg_sort_last_dim",
|
||||
});
|
||||
}
|
||||
let last_dim = match self.dims().last() {
|
||||
None => crate::bail!("empty last-dim in arg-sort"),
|
||||
Some(last_dim) => *last_dim,
|
||||
};
|
||||
// No need for a backward pass for arg sort.
|
||||
self.apply_op1_no_bwd(&ArgSort { asc, last_dim })
|
||||
}
|
||||
|
||||
/// Sorts the tensor along the last dimension, returns the sorted tensor together with the
|
||||
/// sorted indexes.
|
||||
///
|
||||
/// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in
|
||||
/// descending order. The sort is unstable so there is no guarantees on the final order when it
|
||||
/// comes to ties.
|
||||
pub fn sort_last_dim(&self, asc: bool) -> Result<(Tensor, Tensor)> {
|
||||
if !self.is_contiguous() {
|
||||
return Err(crate::Error::RequiresContiguous {
|
||||
op: "sort_last_dim",
|
||||
});
|
||||
}
|
||||
let asort = self.arg_sort_last_dim(asc)?;
|
||||
let sorted = self.gather(&asort, crate::D::Minus1)?;
|
||||
Ok((sorted, asort))
|
||||
}
|
||||
}
|
@ -456,15 +456,7 @@ impl Tensor {
|
||||
shape: S,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let n: usize = shape.elem_count();
|
||||
let buffer_size: usize = array.len();
|
||||
if buffer_size != n {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let storage = device.storage_from_slice(array)?;
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, false))
|
||||
Self::new_impl(array, shape.into(), device, false)
|
||||
}
|
||||
|
||||
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
|
||||
|
@ -34,14 +34,9 @@ impl Var {
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
// Convert a tensor to a variable, if the tensor is already a variable then it is returned as is.
|
||||
pub fn from_tensor(t: &Tensor) -> Result<Self> {
|
||||
if t.is_variable() {
|
||||
Ok(Self(t.clone()))
|
||||
} else {
|
||||
let inner = t.make_var()?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
let inner = t.make_var()?;
|
||||
Ok(Self(inner))
|
||||
}
|
||||
|
||||
pub fn rand_f64<S: Into<Shape>>(
|
||||
|
@ -3,7 +3,7 @@ use candle_core::{
|
||||
quantized::{self, GgmlDType},
|
||||
test_device,
|
||||
test_utils::to_vec2_round,
|
||||
DType, Device, IndexOp, Module, Result, Tensor,
|
||||
Device, IndexOp, Module, Result, Tensor,
|
||||
};
|
||||
use quantized::{k_quants, GgmlType};
|
||||
use rand::prelude::*;
|
||||
@ -193,25 +193,17 @@ fn qmm_batch(dev: &Device) -> Result<()> {
|
||||
let mm3 = rhs.forward(&lhs3)?;
|
||||
assert_eq!(mm3.shape().dims(), [6, 6]);
|
||||
let diff3 = (mm3.i(2..4)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff3, 0.0);
|
||||
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff3, 0.0);
|
||||
let lhs4 = Tensor::cat(&[&lhs3, &lhs3], 0)?;
|
||||
let mm4 = rhs.forward(&lhs4)?;
|
||||
assert_eq!(mm4.shape().dims(), [12, 6]);
|
||||
let diff4 = (mm4.i(..6)? - &mm3)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if dev.is_cuda() {
|
||||
// We use a different kernel for sizes from 1 to 8 on cuda which explains
|
||||
// the difference here.
|
||||
assert!(0. < diff4 && diff4 < 1e-4)
|
||||
assert!(diff3 < 1e-4)
|
||||
} else {
|
||||
assert_eq!(diff4, 0.0)
|
||||
assert_eq!(diff3, 0.0)
|
||||
};
|
||||
let diff3 = (mm3.i(4..)? - &mm)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
if dev.is_cuda() {
|
||||
assert!(diff3 < 1e-4)
|
||||
} else {
|
||||
assert_eq!(diff3, 0.0)
|
||||
};
|
||||
let diff4 = (mm4.i(6..)? - &mm4.i(..6)?)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff4, 0.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -225,13 +217,6 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
dst.to_vec1::<f32>()?,
|
||||
&[
|
||||
@ -258,13 +243,6 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
@ -291,13 +269,6 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
@ -324,13 +295,6 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
assert_eq!(
|
||||
round_vector(&dst.to_vec1::<f32>()?),
|
||||
&[
|
||||
@ -415,13 +379,6 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
||||
let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
|
||||
if error > max_error {
|
||||
bail!(
|
||||
@ -439,13 +396,6 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -465,13 +415,6 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -486,13 +429,6 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -512,13 +448,6 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -533,13 +462,6 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -559,13 +481,6 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -580,13 +495,6 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -606,13 +514,6 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -627,13 +528,6 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -653,13 +547,6 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
@ -674,13 +561,6 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
let dst = quant.dequantize(device)?;
|
||||
let dst_f16 = quant.dequantize_f16(device)?;
|
||||
let diff = (dst.to_dtype(DType::F16)? - dst_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src = src.to_vec1::<f32>()?;
|
||||
let dst = dst.to_vec1::<f32>()?;
|
||||
@ -700,13 +580,6 @@ fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
let src_big = get_test_vector2(128.0, 1024, device)?;
|
||||
let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
|
||||
let dst_big = quant_big.dequantize(device)?;
|
||||
let dst_big_f16 = quant_big.dequantize_f16(device)?;
|
||||
let diff = (dst_big.to_dtype(DType::F16)? - dst_big_f16)?
|
||||
.to_dtype(DType::F32)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
|
||||
let src_big = src_big.to_vec1::<f32>()?;
|
||||
let dst_big = dst_big.to_vec1::<f32>()?;
|
||||
|
@ -96,40 +96,6 @@ fn clamp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn asort(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let indexes = tensor.arg_sort_last_dim(true)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||
);
|
||||
let indexes = tensor.arg_sort_last_dim(false)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
let (sorted, indexes) = tensor.sort_last_dim(true)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||
);
|
||||
assert_eq!(
|
||||
sorted.to_vec2::<f32>()?,
|
||||
[[1.0, 1.1, 3.0, 4.0, 5.0], [1.0, 2.0, 2.1, 7.0, 8.0]]
|
||||
);
|
||||
let (sorted, indexes) = tensor.sort_last_dim(false)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
assert_eq!(
|
||||
sorted.to_vec2::<f32>()?,
|
||||
[[5.0, 4.0, 3.0, 1.1, 1.0], [8.0, 7.0, 2.1, 2.0, 1.0]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
@ -1185,7 +1151,6 @@ test_device!(
|
||||
);
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
||||
|
||||
|
@ -17,7 +17,7 @@ use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
@ -31,8 +31,6 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
enum Which {
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V3Instruct,
|
||||
#[value(name = "solar-10.7b")]
|
||||
Solar10_7B,
|
||||
#[value(name = "tiny-llama-1.1b-chat")]
|
||||
@ -47,23 +45,19 @@ struct Args {
|
||||
cpu: bool,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 10000)]
|
||||
#[arg(long, default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// Disable the key-value cache.
|
||||
@ -89,18 +83,18 @@ struct Args {
|
||||
revision: Option<String>,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "v3")]
|
||||
#[arg(long, default_value = "v2")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 128)]
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
@ -126,13 +120,11 @@ fn main() -> Result<()> {
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, mut cache, config) = {
|
||||
let (llama, tokenizer_filename, mut cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
|
||||
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
|
||||
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
|
||||
});
|
||||
@ -146,7 +138,7 @@ fn main() -> Result<()> {
|
||||
let config = config.into_config(args.use_flash_attn);
|
||||
|
||||
let filenames = match args.which {
|
||||
Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
|
||||
Which::V1 | Which::V2 | Which::Solar10_7B => {
|
||||
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||
}
|
||||
Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
|
||||
@ -154,12 +146,10 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache, config)
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = config
|
||||
.eos_token_id
|
||||
.or_else(|| tokenizer.token_to_id(EOS_TOKEN));
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||
let mut tokens = tokenizer
|
||||
.encode(prompt, true)
|
||||
@ -170,22 +160,8 @@ fn main() -> Result<()> {
|
||||
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let mut start_gen = std::time::Instant::now();
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
let mut token_generated = 0;
|
||||
for index in 0..args.sample_len {
|
||||
@ -194,9 +170,6 @@ fn main() -> Result<()> {
|
||||
} else {
|
||||
(tokens.len(), 0)
|
||||
};
|
||||
if index == 1 {
|
||||
start_gen = std::time::Instant::now()
|
||||
}
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, context_index, &mut cache)?;
|
||||
@ -232,7 +205,7 @@ fn main() -> Result<()> {
|
||||
println!(
|
||||
"\n\n{} tokens generated ({} token/s)\n",
|
||||
token_generated,
|
||||
(token_generated - 1) as f64 / dt.as_secs_f64(),
|
||||
token_generated as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -10,7 +10,7 @@
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{bail, Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use clap::Parser;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
@ -24,15 +24,57 @@ mod model;
|
||||
use model::{Config, Llama};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
const DEFAULT_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
Or whether he be 'scaped away or no
|
||||
From Clifford's and Northumberland's pursuit:
|
||||
Had he been ta'en, we should have heard the news;
|
||||
Had he been slain, we should have heard the news;
|
||||
Or had he 'scaped, methinks we should have heard
|
||||
The happy tidings of his good escape.
|
||||
How fares my brother? why is he so sad?
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
V2_7b,
|
||||
V2_70b,
|
||||
V3_8b,
|
||||
V3_70b,
|
||||
}
|
||||
RICHARD:
|
||||
I cannot joy, until I be resolved
|
||||
Where our right valiant father is become.
|
||||
I saw him in the battle range about;
|
||||
And watch'd him how he singled Clifford forth.
|
||||
Methought he bore him in the thickest troop
|
||||
As doth a lion in a herd of neat;
|
||||
Or as a bear, encompass'd round with dogs,
|
||||
Who having pinch'd a few and made them cry,
|
||||
The rest stand all aloof, and bark at him.
|
||||
So fared our father with his enemies;
|
||||
So fled his enemies my warlike father:
|
||||
Methinks, 'tis prize enough to be his son.
|
||||
See how the morning opes her golden gates,
|
||||
And takes her farewell of the glorious sun!
|
||||
How well resembles it the prime of youth,
|
||||
Trimm'd like a younker prancing to his love!
|
||||
|
||||
EDWARD:
|
||||
Dazzle mine eyes, or do I see three suns?
|
||||
|
||||
RICHARD:
|
||||
Three glorious suns, each one a perfect sun;
|
||||
Not separated with the racking clouds,
|
||||
But sever'd in a pale clear-shining sky.
|
||||
See, see! they join, embrace, and seem to kiss,
|
||||
As if they vow'd some league inviolable:
|
||||
Now are they but one lamp, one light, one sun.
|
||||
In this the heaven figures some event.
|
||||
|
||||
EDWARD:
|
||||
'Tis wondrous strange, the like yet never heard of.
|
||||
I think it cites us, brother, to the field,
|
||||
That we, the sons of brave Plantagenet,
|
||||
Each one already blazing by our meeds,
|
||||
Should notwithstanding join our lights together
|
||||
And over-shine the earth as this the world.
|
||||
Whate'er it bodes, henceforward will I bear
|
||||
Upon my target three fair-shining suns.
|
||||
";
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@ -44,8 +86,8 @@ struct Args {
|
||||
rank: Option<usize>,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
@ -75,12 +117,6 @@ struct Args {
|
||||
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "v3-8b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long, default_value = "nccl_id.txt")]
|
||||
comm_file: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -93,27 +129,14 @@ fn main() -> Result<()> {
|
||||
Some("bf16") => DType::BF16,
|
||||
Some("f32") => DType::F32,
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => match args.which {
|
||||
Which::V2_7b | Which::V2_70b => DType::F16,
|
||||
Which::V3_8b | Which::V3_70b => DType::BF16,
|
||||
},
|
||||
None => DType::F16,
|
||||
};
|
||||
|
||||
let comm_file = std::path::PathBuf::from(&args.comm_file);
|
||||
if comm_file.exists() {
|
||||
bail!("comm file {comm_file:?} already exists, please remove it first")
|
||||
}
|
||||
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model) => model,
|
||||
None => match args.which {
|
||||
Which::V2_7b => "meta-llama/Llama-2-7b-hf".to_string(),
|
||||
Which::V2_70b => "meta-llama/Llama-2-70b-hf".to_string(),
|
||||
Which::V3_8b => "meta-llama/Meta-Llama-3-8B".to_string(),
|
||||
Which::V3_70b => "meta-llama/Meta-Llama-3-70B".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let model_id = args
|
||||
.model_id
|
||||
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
|
||||
println!("loading the model weights from {model_id}");
|
||||
let revision = args.revision.unwrap_or("main".to_string());
|
||||
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||
@ -122,40 +145,39 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
|
||||
|
||||
let rank = match args.rank {
|
||||
None => {
|
||||
println!("creating {} child processes", args.num_shards);
|
||||
let children: Vec<_> = (0..args.num_shards)
|
||||
.map(|rank| {
|
||||
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
|
||||
args.push_back("--rank".to_string());
|
||||
args.push_back(format!("{rank}"));
|
||||
let name = args.pop_front().unwrap();
|
||||
std::process::Command::new(name).args(args).spawn().unwrap()
|
||||
})
|
||||
.collect();
|
||||
for mut child in children {
|
||||
child.wait()?;
|
||||
}
|
||||
return Ok(());
|
||||
if args.rank.is_none() {
|
||||
let children: Vec<_> = (0..args.num_shards)
|
||||
.map(|rank| {
|
||||
let mut args: std::collections::VecDeque<_> = std::env::args().collect();
|
||||
args.push_back("--rank".to_string());
|
||||
args.push_back(format!("{rank}"));
|
||||
let name = args.pop_front().unwrap();
|
||||
std::process::Command::new(name).args(args).spawn().unwrap()
|
||||
})
|
||||
.collect();
|
||||
for mut child in children {
|
||||
child.wait().unwrap();
|
||||
}
|
||||
Some(rank) => rank,
|
||||
};
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let i = args.rank.unwrap();
|
||||
let num_shards = args.num_shards;
|
||||
let rank = i;
|
||||
// Primitive IPC
|
||||
let id = if rank == 0 {
|
||||
let id = Id::new().unwrap();
|
||||
let tmp_file = comm_file.with_extension(".comm.tgz");
|
||||
std::fs::File::create(&tmp_file)?
|
||||
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
|
||||
std::fs::rename(&tmp_file, &comm_file)?;
|
||||
std::fs::File::create("nccl_id.txt.tmp")?
|
||||
.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())
|
||||
.unwrap();
|
||||
std::fs::rename("nccl_id.txt.tmp", "nccl_id.txt")?;
|
||||
id
|
||||
} else {
|
||||
while !comm_file.exists() {
|
||||
let path = std::path::PathBuf::from("nccl_id.txt");
|
||||
while !path.exists() {
|
||||
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||
}
|
||||
let data = std::fs::read(&comm_file)?;
|
||||
let data = std::fs::read("nccl_id.txt")?;
|
||||
let internal: [i8; 128] = data
|
||||
.into_iter()
|
||||
.map(|i| i as i8)
|
||||
@ -165,17 +187,14 @@ fn main() -> Result<()> {
|
||||
let id: Id = Id::uninit(internal);
|
||||
id
|
||||
};
|
||||
let device = CudaDevice::new(rank)?;
|
||||
let comm = match Comm::from_rank(device, rank, num_shards, id) {
|
||||
Ok(comm) => Rc::new(comm),
|
||||
Err(err) => anyhow::bail!("nccl error {:?}", err.0),
|
||||
};
|
||||
let device = CudaDevice::new(i)?;
|
||||
let comm = Rc::new(Comm::from_rank(device, i, num_shards, id).unwrap());
|
||||
if rank == 0 {
|
||||
std::fs::remove_file(comm_file)?;
|
||||
std::fs::remove_file("nccl_id.txt")?;
|
||||
}
|
||||
println!("Rank {rank:?} spawned");
|
||||
|
||||
let device = Device::new_cuda(rank)?;
|
||||
let device = Device::new_cuda(i)?;
|
||||
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||
|
||||
println!("building the model");
|
||||
@ -191,24 +210,14 @@ fn main() -> Result<()> {
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
|
||||
|
||||
println!("starting the inference loop");
|
||||
let temperature = if args.temperature <= 0. {
|
||||
None
|
||||
} else {
|
||||
Some(args.temperature)
|
||||
};
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
|
||||
let mut new_tokens = vec![];
|
||||
let mut start_gen = std::time::Instant::now();
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut index_pos = 0;
|
||||
for index in 0..args.sample_len {
|
||||
// Only start timing at the second token as processing the first token waits for all the
|
||||
// weights to be loaded in an async way.
|
||||
if index == 1 {
|
||||
start_gen = std::time::Instant::now()
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
@ -219,23 +228,25 @@ fn main() -> Result<()> {
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
if Some(next_token) == config.eos_token_id {
|
||||
break;
|
||||
}
|
||||
if rank == 0 {
|
||||
if let Some(t) = tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
println!("> {:?}", start_gen.elapsed());
|
||||
println!(
|
||||
"{} token: {} '{}'",
|
||||
index + 1,
|
||||
next_token,
|
||||
tokenizer.decode(&[next_token], true).map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
}
|
||||
println!();
|
||||
let dt = start_gen.elapsed();
|
||||
if rank == 0 {
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"\n\n{} tokens generated ({} token/s)\n",
|
||||
"{} tokens generated ({} token/s)\n----\n{}\n----",
|
||||
args.sample_len,
|
||||
(args.sample_len - 1) as f64 / dt.as_secs_f64(),
|
||||
args.sample_len as f64 / dt.as_secs_f64(),
|
||||
tokenizer
|
||||
.decode(new_tokens.as_slice(), true)
|
||||
.map_err(E::msg)?
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
|
@ -1,14 +1,15 @@
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
|
||||
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
|
||||
use candle_nn::{Embedding, Linear, Module, RmsNorm};
|
||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||
use half::f16;
|
||||
use serde::Deserialize;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
pub type Config = candle_transformers::models::llama::LlamaConfig;
|
||||
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
|
||||
|
||||
struct TensorParallelColumnLinear {
|
||||
linear: Linear,
|
||||
@ -25,7 +26,7 @@ impl TensorParallelColumnLinear {
|
||||
|
||||
struct TensorParallelRowLinear {
|
||||
linear: Linear,
|
||||
all_reduce: AllReduce,
|
||||
comm: Rc<Comm>,
|
||||
}
|
||||
|
||||
struct AllReduce {
|
||||
@ -35,6 +36,8 @@ struct AllReduce {
|
||||
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
|
||||
/// But for this example purposes, this will work
|
||||
unsafe impl Sync for AllReduce {}
|
||||
/// This is actually not safe: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/threadsafety.html
|
||||
/// But for this example purposes, this will work
|
||||
unsafe impl Send for AllReduce {}
|
||||
|
||||
impl CustomOp1 for AllReduce {
|
||||
@ -43,7 +46,7 @@ impl CustomOp1 for AllReduce {
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, _s: &CpuStorage, _l: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
candle::bail!("AllReduce is never used on cpu")
|
||||
todo!("implement allreduce for cpu is not necessary for single node");
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
@ -53,49 +56,31 @@ impl CustomOp1 for AllReduce {
|
||||
l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::cuda_backend::WrapErr;
|
||||
use cudarc::driver::DeviceSlice;
|
||||
use half::{bf16, f16};
|
||||
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dev = s.device().clone();
|
||||
let dst = match s.dtype() {
|
||||
DType::BF16 => {
|
||||
let s = s.as_cuda_slice::<bf16>()?;
|
||||
let s = match l.contiguous_offsets() {
|
||||
Some((0, l)) if l == s.len() => s,
|
||||
Some(_) | None => candle::bail!("input has to be contiguous"),
|
||||
};
|
||||
let mut dst = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||
self.comm
|
||||
.all_reduce(s, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(candle::Error::debug)?;
|
||||
candle::CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
DType::F16 => {
|
||||
let s = s.as_cuda_slice::<f16>()?;
|
||||
let s = match l.contiguous_offsets() {
|
||||
Some((0, l)) if l == s.len() => s,
|
||||
Some(_) | None => candle::bail!("input has to be contiguous"),
|
||||
};
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm
|
||||
.all_reduce(s, &mut dst, &ReduceOp::Sum)
|
||||
.map_err(candle::Error::debug)?;
|
||||
candle::CudaStorage::wrap_cuda_slice(dst, dev)
|
||||
}
|
||||
dtype => candle::bail!("unsupported dtype {dtype:?}"),
|
||||
};
|
||||
let s = s.as_cuda_slice::<f16>()?;
|
||||
// let s = match l.contiguous_offsets() {
|
||||
// None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
// Some((o1, o2)) => s.slice(o1..o2),
|
||||
// };
|
||||
let mut dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||
self.comm.all_reduce(s, &mut dst, &ReduceOp::Sum).unwrap();
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
|
||||
x.apply_op1(AllReduce { comm: comm.clone() })
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
fn new(linear: Linear, comm: Rc<Comm>) -> Self {
|
||||
let all_reduce = AllReduce { comm };
|
||||
Self { linear, all_reduce }
|
||||
Self { linear, comm }
|
||||
}
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
self.linear.forward(x)?.apply_op1_no_bwd(&self.all_reduce)
|
||||
let x = self.linear.forward(x)?;
|
||||
all_reduce_sum(&x, &self.comm)
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,6 +121,23 @@ impl TensorParallelRowLinear {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
/// Fallback value for `Config::rope_theta`, used by serde when the field is missing.
fn default_rope() -> f32 {
    10_000.0
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Cache {
|
||||
#[allow(clippy::type_complexity)]
|
||||
@ -159,6 +161,7 @@ impl Cache {
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
Ok(Self {
|
||||
@ -194,10 +197,16 @@ struct CausalSelfAttention {
|
||||
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, _, seq_len, _hidden_size) = x.shape().dims4()?;
|
||||
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
|
||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
candle_nn::rotary_emb::rope(x, &cos, &sin)
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||
@ -223,16 +232,13 @@ impl CausalSelfAttention {
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
@ -263,14 +269,25 @@ impl CausalSelfAttention {
|
||||
let v = v.transpose(1, 2)?;
|
||||
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||
.reshape((b_sz, seq_len, hidden_size))?;
|
||||
.transpose(1, 2)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||
let y = self.o_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||
let n_rep = self.num_attention_heads / self.num_key_value_heads;
|
||||
candle_transformers::utils::repeat_kv(x, n_rep)
|
||||
if n_rep == 1 {
|
||||
Ok(x)
|
||||
} else {
|
||||
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().dims4()?;
|
||||
let x = x
|
||||
.unsqueeze(2)?
|
||||
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
|
||||
@ -284,7 +301,7 @@ impl CausalSelfAttention {
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
|
||||
num_key_value_heads: cfg.num_key_value_heads() / comm.world_size(),
|
||||
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
|
||||
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||
cache: cache.clone(),
|
||||
})
|
||||
@ -298,6 +315,18 @@ struct Mlp {
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(
|
||||
c_fc1: TensorParallelColumnLinear,
|
||||
c_fc2: TensorParallelColumnLinear,
|
||||
c_proj: TensorParallelRowLinear,
|
||||
) -> Self {
|
||||
Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||
self.c_proj.forward(&x)
|
||||
@ -307,11 +336,7 @@ impl Mlp {
|
||||
let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
|
||||
let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
|
||||
let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm)?;
|
||||
Ok(Self {
|
||||
c_fc1,
|
||||
c_fc2,
|
||||
c_proj,
|
||||
})
|
||||
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||
}
|
||||
}
|
||||
|
||||
@ -402,8 +427,10 @@ impl Llama {
|
||||
cfg,
|
||||
comm.clone(),
|
||||
)
|
||||
.unwrap()
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
.collect();
|
||||
|
||||
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||
}
|
||||
}
|
||||
|
@ -1,36 +0,0 @@
|
||||
# candle-olmo: Open Language Models designed to enable the science of language models
|
||||
|
||||
OLMo is a series of Open Language Models designed to enable the science of language models.
|
||||
|
||||
- **Project Page:** https://allenai.org/olmo
|
||||
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
|
||||
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
|
||||
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
|
||||
<!-- - **Press release:** TODO -->
|
||||
|
||||
## Running the example
|
||||
|
||||
```bash
|
||||
$ cargo run --example olmo --release -- --prompt "It is only with the heart that one can see rightly"
|
||||
|
||||
avx: true, neon: false, simd128: false, f16c: true
|
||||
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
|
||||
retrieved the files in 354.977µs
|
||||
loaded the model in 19.87779666s
|
||||
It is only with the heart that one can see rightly; what is essential is invisible to the eye.
|
||||
```
|
||||
|
||||
Various model sizes are available via the `--model` argument.
|
||||
|
||||
```bash
|
||||
$ cargo run --example olmo --release -- --model 1.7-7b --prompt 'It is only with the heart that one can see rightly'
|
||||
|
||||
avx: true, neon: false, simd128: false, f16c: true
|
||||
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64
|
||||
retrieved the files in 1.226087ms
|
||||
loaded the model in 171.274578609s
|
||||
It is only with the heart that one can see rightly; what is essential is invisible to the eye.”
|
||||
~ Antoine de Saint-Exupery, The Little Prince
|
||||
I am a big fan of this quote. It reminds me that I need to be open and aware of my surroundings in order to truly appreciate them.
|
||||
```
|
||||
|
@ -1,284 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum Model {
|
||||
OLMo(OLMo),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, false)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = match &mut self.model {
|
||||
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
||||
enum Which {
|
||||
#[value(name = "1b")]
|
||||
W1b,
|
||||
#[value(name = "7b")]
|
||||
W7b,
|
||||
#[value(name = "7b-twin-2t")]
|
||||
W7bTwin2T,
|
||||
#[value(name = "1.7-7b")]
|
||||
V1_7W7b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long, default_value = "1b")]
|
||||
model: Which,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => match args.model {
|
||||
Which::W1b => "allenai/OLMo-1B-hf".to_string(),
|
||||
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
||||
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
||||
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => match args.model {
|
||||
Which::W1b => {
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
},
|
||||
};
|
||||
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
config
|
||||
};
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = {
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = OLMo::new(&config, vb)?;
|
||||
Model::OLMo(model)
|
||||
};
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,9 +1,8 @@
|
||||
# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
|
||||
|
||||
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5),
|
||||
[Phi-2](https://huggingface.co/microsoft/phi-2), and
|
||||
[Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) are language models using
|
||||
only 1.3, 2.7, and 3.8 billion parameters but with state of the art performance compared to
|
||||
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
|
||||
[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
|
||||
only 1.3 and 2.7 billion parameters but with state of the art performance compared to
|
||||
models with up to 10 billion parameters.
|
||||
|
||||
The candle implementation provides both the standard version as well as a
|
||||
|
@ -7,13 +7,11 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
|
||||
use candle_transformers::models::phi::{Config as PhiConfig, Model as Phi};
|
||||
use candle_transformers::models::phi3::{Config as Phi3Config, Model as Phi3};
|
||||
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
|
||||
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
@ -22,14 +20,13 @@ use tokenizers::Tokenizer;
|
||||
enum Model {
|
||||
MixFormer(MixFormer),
|
||||
Phi(Phi),
|
||||
Phi3(Phi3),
|
||||
Quantized(QMixFormer),
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
@ -52,7 +49,7 @@ impl TextGeneration {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
@ -64,11 +61,7 @@ impl TextGeneration {
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
let tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||
if tokens.is_empty() {
|
||||
anyhow::bail!("Empty prompts are not supported in the phi model.")
|
||||
}
|
||||
@ -80,14 +73,13 @@ impl TextGeneration {
|
||||
}
|
||||
let mut tokens = tokens.get_ids().to_vec();
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
|
||||
Some(token) => *token,
|
||||
None => anyhow::bail!("cannot find the endoftext token"),
|
||||
};
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let start_gen = std::time::Instant::now();
|
||||
let mut pos = 0;
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
@ -96,7 +88,6 @@ impl TextGeneration {
|
||||
Model::MixFormer(m) => m.forward(&input)?,
|
||||
Model::Phi(m) => m.forward(&input)?,
|
||||
Model::Quantized(m) => m.forward(&input)?,
|
||||
Model::Phi3(m) => m.forward(&input, pos)?.i((.., 0, ..))?,
|
||||
};
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
@ -116,11 +107,9 @@ impl TextGeneration {
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
pos += context_size;
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
@ -139,8 +128,6 @@ enum WhichModel {
|
||||
V1_5,
|
||||
#[value(name = "2")]
|
||||
V2,
|
||||
#[value(name = "3")]
|
||||
V3,
|
||||
#[value(name = "2-old")]
|
||||
V2Old,
|
||||
PuffinPhiV2,
|
||||
@ -209,10 +196,6 @@ struct Args {
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The dtype to be used for running the model, e.g. f32, bf16, or f16.
|
||||
#[arg(long)]
|
||||
dtype: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -253,7 +236,6 @@ fn main() -> Result<()> {
|
||||
WhichModel::V1 => "microsoft/phi-1".to_string(),
|
||||
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
|
||||
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
|
||||
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"lmz/candle-quantized-phi".to_string()
|
||||
}
|
||||
@ -271,10 +253,9 @@ fn main() -> Result<()> {
|
||||
WhichModel::V1 => "refs/pr/8".to_string(),
|
||||
WhichModel::V1_5 => "refs/pr/73".to_string(),
|
||||
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
|
||||
WhichModel::V2
|
||||
| WhichModel::V3
|
||||
| WhichModel::PuffinPhiV2
|
||||
| WhichModel::PhiHermes => "main".to_string(),
|
||||
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
"main".to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -283,11 +264,9 @@ fn main() -> Result<()> {
|
||||
let tokenizer_filename = match args.tokenizer {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => match args.model {
|
||||
WhichModel::V1
|
||||
| WhichModel::V1_5
|
||||
| WhichModel::V2
|
||||
| WhichModel::V2Old
|
||||
| WhichModel::V3 => repo.get("tokenizer.json")?,
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 | WhichModel::V2Old => {
|
||||
repo.get("tokenizer.json")?
|
||||
}
|
||||
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
|
||||
repo.get("tokenizer-puffin-phi-v2.json")?
|
||||
}
|
||||
@ -303,19 +282,14 @@ fn main() -> Result<()> {
|
||||
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
|
||||
WhichModel::V3 => anyhow::bail!(
|
||||
"use the quantized or quantized-phi examples for quantized phi-v3"
|
||||
),
|
||||
}
|
||||
} else {
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
|
||||
candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
)?
|
||||
}
|
||||
WhichModel::V2 | WhichModel::V2Old => candle_examples::hub_load_safetensors(
|
||||
&repo,
|
||||
"model.safetensors.index.json",
|
||||
)?,
|
||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
||||
}
|
||||
@ -332,9 +306,6 @@ fn main() -> Result<()> {
|
||||
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
|
||||
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
|
||||
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
|
||||
WhichModel::V3 => {
|
||||
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
|
||||
}
|
||||
};
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = if args.quantized {
|
||||
@ -349,17 +320,7 @@ fn main() -> Result<()> {
|
||||
};
|
||||
Model::Quantized(model)
|
||||
} else {
|
||||
let dtype = match args.dtype {
|
||||
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
|
||||
None => {
|
||||
if args.model == WhichModel::V3 && device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
}
|
||||
}
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
match args.model {
|
||||
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
@ -368,13 +329,6 @@ fn main() -> Result<()> {
|
||||
let phi = Phi::new(&config, vb)?;
|
||||
Model::Phi(phi)
|
||||
}
|
||||
WhichModel::V3 => {
|
||||
let config_filename = repo.get("config.json")?;
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Phi3Config = serde_json::from_str(&config)?;
|
||||
let phi3 = Phi3::new(&config, vb)?;
|
||||
Model::Phi3(phi3)
|
||||
}
|
||||
WhichModel::V2Old => {
|
||||
let config = config();
|
||||
Model::MixFormer(MixFormer::new_v2(&config, vb)?)
|
||||
@ -467,10 +421,6 @@ fn mmlu<P: AsRef<std::path::Path>>(
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
}
|
||||
Model::Phi3(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input, 0)?
|
||||
}
|
||||
Model::Quantized(m) => {
|
||||
m.clear_kv_cache();
|
||||
m.forward(&input)?
|
||||
|
@ -1,317 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::Tensor;
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
|
||||
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
|
||||
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
|
||||
|
||||
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "phi-2")]
|
||||
Phi2,
|
||||
#[value(name = "phi-3")]
|
||||
Phi3,
|
||||
/// Alternative implementation of phi-3, based on llama.
|
||||
#[value(name = "phi-3b")]
|
||||
Phi3b,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||
/// is preserved.
|
||||
#[arg(long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||
sample_len: usize,
|
||||
|
||||
/// The tokenizer config in json format.
|
||||
#[arg(long)]
|
||||
tokenizer: Option<String>,
|
||||
|
||||
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f64,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// Process prompt elements separately.
|
||||
#[arg(long)]
|
||||
split_prompt: bool,
|
||||
|
||||
/// Run on CPU rather than GPU even if a GPU is available.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "phi-3b")]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||
let tokenizer_path = match &self.tokenizer {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let repo = match self.which {
|
||||
Which::Phi2 => "microsoft/phi-2",
|
||||
Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
|
||||
};
|
||||
let api = api.model(repo.to_string());
|
||||
api.get("tokenizer.json")?
|
||||
}
|
||||
};
|
||||
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
|
||||
}
|
||||
|
||||
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||
let model_path = match &self.model {
|
||||
Some(config) => std::path::PathBuf::from(config),
|
||||
None => {
|
||||
let (repo, filename, revision) = match self.which {
|
||||
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"),
|
||||
Which::Phi3 => (
|
||||
"microsoft/Phi-3-mini-4k-instruct-gguf",
|
||||
"Phi-3-mini-4k-instruct-q4.gguf",
|
||||
"main",
|
||||
),
|
||||
Which::Phi3b => (
|
||||
"microsoft/Phi-3-mini-4k-instruct-gguf",
|
||||
"Phi-3-mini-4k-instruct-q4.gguf",
|
||||
"5eef2ce24766d31909c0b269fe90c817a8f263fb",
|
||||
),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders a byte count as a short human-readable string.
///
/// Uses decimal units (1 KB = 1000 B). Values below 1000 are printed as an
/// exact integer with a `B` suffix; larger values are printed with two
/// decimals and a `KB`, `MB` or `GB` suffix.
fn format_size(size_in_bytes: usize) -> String {
    let bytes = size_in_bytes as f64;
    match size_in_bytes {
        0..=999 => format!("{}B", size_in_bytes),
        1_000..=999_999 => format!("{:.2}KB", bytes / 1e3),
        1_000_000..=999_999_999 => format!("{:.2}MB", bytes / 1e6),
        _ => format!("{:.2}GB", bytes / 1e9),
    }
}
|
||||
|
||||
enum Model {
|
||||
Phi2(Phi2),
|
||||
Phi3(Phi3),
|
||||
Phi3b(Phi3b),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&mut self, xs: &Tensor, pos: usize) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Phi2(m) => m.forward(xs, pos),
|
||||
Self::Phi3(m) => m.forward(xs, pos),
|
||||
Self::Phi3b(m) => m.forward(xs, pos),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||
);
|
||||
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let mut model = {
|
||||
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
||||
let mut total_size_in_bytes = 0;
|
||||
for (_, tensor) in model.tensor_infos.iter() {
|
||||
let elem_count = tensor.shape.elem_count();
|
||||
total_size_in_bytes +=
|
||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||
}
|
||||
println!(
|
||||
"loaded {:?} tensors ({}) in {:.2}s",
|
||||
model.tensor_infos.len(),
|
||||
&format_size(total_size_in_bytes),
|
||||
start.elapsed().as_secs_f32(),
|
||||
);
|
||||
match args.which {
|
||||
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
|
||||
}
|
||||
};
|
||||
println!("model built");
|
||||
|
||||
let tokenizer = args.tokenizer()?;
|
||||
let mut tos = TokenOutputStream::new(tokenizer);
|
||||
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||
print!("{}", &prompt_str);
|
||||
let tokens = tos
|
||||
.tokenizer()
|
||||
.encode(prompt_str, true)
|
||||
.map_err(anyhow::Error::msg)?;
|
||||
let tokens = tokens.get_ids();
|
||||
let to_sample = args.sample_len.saturating_sub(1);
|
||||
let mut all_tokens = vec![];
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let eos_token = *tos
|
||||
.tokenizer()
|
||||
.get_vocab(true)
|
||||
.get("<|endoftext|>")
|
||||
.unwrap();
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
args.repeat_penalty,
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
if let Some(t) = tos.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
sampled += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
};
|
||||
}
|
||||
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
let dt = start_post_prompt.elapsed();
|
||||
println!(
|
||||
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||
tokens.len(),
|
||||
tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||
);
|
||||
println!(
|
||||
"{sampled:4} tokens generated: {:.2} token/s",
|
||||
sampled as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
@ -67,10 +67,8 @@ enum Which {
|
||||
Mixtral,
|
||||
#[value(name = "mixtral-instruct")]
|
||||
MixtralInstruct,
|
||||
#[value(name = "llama3-8b")]
|
||||
L8b,
|
||||
#[value(name = "phi3")]
|
||||
Phi3,
|
||||
#[value(name = "phi-2")]
|
||||
Phi2,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
@ -87,8 +85,7 @@ impl Which {
|
||||
| Self::L34bCode
|
||||
| Self::Leo7b
|
||||
| Self::Leo13b
|
||||
| Self::L8b
|
||||
| Self::Phi3 => false,
|
||||
| Self::Phi2 => false,
|
||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||
// same way. Starling is a fine tuned version of OpenChat.
|
||||
Self::OpenChat35
|
||||
@ -122,9 +119,8 @@ impl Which {
|
||||
| Self::Mistral7bInstruct
|
||||
| Self::Mistral7bInstructV02
|
||||
| Self::OpenChat35
|
||||
| Self::Starling7bAlpha
|
||||
| Self::L8b
|
||||
| Self::Phi3 => false,
|
||||
| Self::Phi2
|
||||
| Self::Starling7bAlpha => false,
|
||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||
}
|
||||
}
|
||||
@ -147,10 +143,9 @@ impl Which {
|
||||
| Self::Mistral7b
|
||||
| Self::Mistral7bInstruct
|
||||
| Self::Mistral7bInstructV02
|
||||
| Self::Phi2
|
||||
| Self::Zephyr7bAlpha
|
||||
| Self::Zephyr7bBeta
|
||||
| Self::L8b
|
||||
| Self::Phi3 => false,
|
||||
| Self::Zephyr7bBeta => false,
|
||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
||||
}
|
||||
}
|
||||
@ -177,8 +172,7 @@ impl Which {
|
||||
| Self::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
|
||||
Self::OpenChat35 => "openchat/openchat_3.5",
|
||||
Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
|
||||
Self::L8b => "meta-llama/Meta-Llama-3-8B",
|
||||
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
||||
Self::Phi2 => "microsoft/phi-2",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -334,28 +328,11 @@ impl Args {
|
||||
"TheBloke/Starling-LM-7B-alpha-GGUF",
|
||||
"starling-lm-7b-alpha.Q4_K_M.gguf",
|
||||
),
|
||||
// TODO: swap to TheBloke model when available
|
||||
Which::L8b => (
|
||||
"QuantFactory/Meta-Llama-3-8B-GGUF",
|
||||
"Meta-Llama-3-8B.Q4_K_S.gguf",
|
||||
),
|
||||
Which::Phi3 => (
|
||||
"microsoft/Phi-3-mini-4k-instruct-gguf",
|
||||
"Phi-3-mini-4k-instruct-q4.gguf",
|
||||
),
|
||||
};
|
||||
let revision = if self.which == Which::Phi3 {
|
||||
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
|
||||
} else {
|
||||
"main"
|
||||
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"),
|
||||
};
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
api.repo(hf_hub::Repo::with_revision(
|
||||
repo.to_string(),
|
||||
hf_hub::RepoType::Model,
|
||||
revision.to_string(),
|
||||
))
|
||||
.get(filename)?
|
||||
let api = api.model(repo.to_string());
|
||||
api.get(filename)?
|
||||
}
|
||||
};
|
||||
Ok(model_path)
|
||||
@ -383,9 +360,6 @@ fn main() -> anyhow::Result<()> {
|
||||
#[cfg(feature = "cuda")]
|
||||
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
||||
|
||||
candle::cuda::set_gemm_reduced_precision_f16(true);
|
||||
candle::cuda::set_gemm_reduced_precision_bf16(true);
|
||||
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
@ -454,8 +428,7 @@ fn main() -> anyhow::Result<()> {
|
||||
| Which::L34bCode
|
||||
| Which::Leo7b
|
||||
| Which::Leo13b
|
||||
| Which::L8b
|
||||
| Which::Phi3 => 1,
|
||||
| Which::Phi2 => 1,
|
||||
Which::Mixtral
|
||||
| Which::MixtralInstruct
|
||||
| Which::Mistral7b
|
||||
@ -572,14 +545,11 @@ fn main() -> anyhow::Result<()> {
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
let eos_token = match args.which {
|
||||
Which::L8b => "<|end_of_text|>",
|
||||
_ => match args.which.is_open_chat() {
|
||||
true => "<|end_of_turn|>",
|
||||
false => "</s>",
|
||||
},
|
||||
let eos_token = if args.which.is_open_chat() {
|
||||
"<|end_of_turn|>"
|
||||
} else {
|
||||
"</s>"
|
||||
};
|
||||
|
||||
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
let mut sampled = 0;
|
||||
|
@ -13,7 +13,7 @@ struct Block {
|
||||
|
||||
impl Block {
|
||||
fn get(&self, key: &str) -> Result<&str> {
|
||||
match self.parameters.get(key) {
|
||||
match self.parameters.get(&key.to_string()) {
|
||||
None => candle::bail!("cannot find {} in {}", key, self.block_type),
|
||||
Some(value) => Ok(value),
|
||||
}
|
||||
@ -28,7 +28,7 @@ pub struct Darknet {
|
||||
|
||||
impl Darknet {
|
||||
fn get(&self, key: &str) -> Result<&str> {
|
||||
match self.parameters.get(key) {
|
||||
match self.parameters.get(&key.to_string()) {
|
||||
None => candle::bail!("cannot find {} in net parameters", key),
|
||||
Some(value) => Ok(value),
|
||||
}
|
||||
|
@ -448,9 +448,9 @@ pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows10
|
||||
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
|
||||
///
|
||||
/// The integrated loudness measurement is not just the average power over the
|
||||
/// entire signal. BS.1770-4 defines two stages of gating that exclude
|
||||
/// entire signal. BS.1770-4 defines defines two stages of gating that exclude
|
||||
/// parts of the signal, to ensure that silent parts do not contribute to the
|
||||
/// loudness measurement. This function performs that gating, and returns the
|
||||
/// loudness measurment. This function performs that gating, and returns the
|
||||
/// average power over the windows that were not excluded.
|
||||
///
|
||||
/// The result of this function is the integrated loudness measurement.
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.5.1"
|
||||
version = "0.5.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.1" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.5.0" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.5.1"
|
||||
version = "0.5.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -6,6 +6,5 @@ pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
|
@ -765,21 +765,20 @@ static __device__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
y[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
extern "C" __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
|
||||
const int64_t i = blockIdx.x;
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8;
|
||||
const int ir = tid%8;
|
||||
const int64_t ib = 8*i + ir;
|
||||
const int ib = 8*i + ir;
|
||||
if (ib >= nb32) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||
float * y = yy + 256*i + 32*ir + 4*il;
|
||||
|
||||
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
|
||||
const float d = __half2float(x->d);
|
||||
@ -793,21 +792,20 @@ static __device__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
extern "C" __global__ void dequantize_block_q4_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
|
||||
const int64_t i = blockIdx.x;
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8;
|
||||
const int ir = tid%8;
|
||||
const int64_t ib = 8*i + ir;
|
||||
const int ib = 8*i + ir;
|
||||
if (ib >= nb32) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||
float * y = yy + 256*i + 32*ir + 4*il;
|
||||
|
||||
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
|
||||
const float2 d = __half22float2(x->dm);
|
||||
@ -822,8 +820,7 @@ static __device__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
|
||||
|
||||
//================================== k-quants
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
extern "C" __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
@ -835,7 +832,7 @@ static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
||||
const int is = 8*n + l/16;
|
||||
|
||||
const uint8_t q = x[i].qs[32*n + l];
|
||||
dst_t * y = yy + i*QK_K + 128*n;
|
||||
float * y = yy + i*QK_K + 128*n;
|
||||
|
||||
float dall = __low2half(x[i].dm);
|
||||
float dmin = __high2half(x[i].dm);
|
||||
@ -847,7 +844,7 @@ static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
||||
const int is = tid/16; // 0 or 1
|
||||
const int il = tid%16; // 0...15
|
||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||
float * y = yy + i*QK_K + 16*is + il;
|
||||
float dall = __low2half(x[i].dm);
|
||||
float dmin = __high2half(x[i].dm);
|
||||
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
|
||||
@ -856,8 +853,7 @@ static __device__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
||||
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
extern "C" __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
@ -881,7 +877,7 @@ static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
|
||||
float d_all = x[i].d;
|
||||
float dl = d_all * (us - 32);
|
||||
|
||||
dst_t * y = yy + i*QK_K + 128*n + 32*j;
|
||||
float * y = yy + i*QK_K + 128*n + 32*j;
|
||||
const uint8_t * q = x[i].qs + 32*n;
|
||||
const uint8_t * hm = x[i].hmask;
|
||||
|
||||
@ -893,7 +889,7 @@ static __device__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
|
||||
const int im = il/8; // 0...1
|
||||
const int in = il%8; // 0...7
|
||||
|
||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||
float * y = yy + i*QK_K + 16*is + il;
|
||||
|
||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
||||
@ -921,8 +917,7 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
|
||||
}
|
||||
#endif
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
extern "C" __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -935,7 +930,7 @@ static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
|
||||
const int is = 2*il;
|
||||
const int n = 4;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
float * y = yy + i*QK_K + 64*il + n*ir;
|
||||
|
||||
const float dall = __low2half(x[i].dm);
|
||||
const float dmin = __high2half(x[i].dm);
|
||||
@ -954,7 +949,7 @@ static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
|
||||
#else
|
||||
const int tid = threadIdx.x;
|
||||
const uint8_t * q = x[i].qs;
|
||||
dst_t * y = yy + i*QK_K;
|
||||
float * y = yy + i*QK_K;
|
||||
const float d = (float)x[i].dm[0];
|
||||
const float m = (float)x[i].dm[1];
|
||||
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
|
||||
@ -962,8 +957,7 @@ static __device__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
extern "C" __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -975,7 +969,7 @@ static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
|
||||
const int ir = tid%16; // ir is in 0...15
|
||||
const int is = 2*il; // is is in 0...6
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
||||
float * y = yy + i*QK_K + 64*il + 2*ir;
|
||||
|
||||
const float dall = __low2half(x[i].dm);
|
||||
const float dmin = __high2half(x[i].dm);
|
||||
@ -1003,26 +997,25 @@ static __device__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
|
||||
const int is = tid/16; // 0 or 1
|
||||
const uint8_t h = x[i].qh[in] >> im;
|
||||
const float d = x[i].d;
|
||||
dst_t * y = yy + i*QK_K + tid;
|
||||
float * y = yy + i*QK_K + tid;
|
||||
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
|
||||
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
extern "C" __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int64_t i = blockIdx.x;
|
||||
const int i = blockIdx.x;
|
||||
#if QK_K == 256
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ip = tid/32; // ip is 0 or 1
|
||||
const int64_t il = tid - 32*ip; // 0...32
|
||||
const int64_t is = 8*ip + il/16;
|
||||
const int tid = threadIdx.x;
|
||||
const int ip = tid/32; // ip is 0 or 1
|
||||
const int il = tid - 32*ip; // 0...32
|
||||
const int is = 8*ip + il/16;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 128*ip + il;
|
||||
float * y = yy + i*QK_K + 128*ip + il;
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
@ -1037,11 +1030,11 @@ static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
||||
#else
|
||||
|
||||
// assume 32 threads
|
||||
const int64_t tid = threadIdx.x;
|
||||
const int64_t ip = tid/16; // 0 or 1
|
||||
const int64_t il = tid - 16*ip; // 0...15
|
||||
const int tid = threadIdx.x;
|
||||
const int ip = tid/16; // 0 or 1
|
||||
const int il = tid - 16*ip; // 0...15
|
||||
|
||||
dst_t * y = yy + i*QK_K + 16*ip + il;
|
||||
float * y = yy + i*QK_K + 16*ip + il;
|
||||
|
||||
const float d = x[i].d;
|
||||
|
||||
@ -1054,8 +1047,7 @@ static __device__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
extern "C" __global__ void dequantize_block_q8_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
const int i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
@ -1067,7 +1059,7 @@ static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t
|
||||
return;
|
||||
}
|
||||
|
||||
dst_t * y = yy + 256*i + 32*ir + 8*il;
|
||||
float * y = yy + 256*i + 32*ir + 8*il;
|
||||
|
||||
const block_q8_0 * x = (const block_q8_0 *)vx + ib;
|
||||
const float d = __half2float(x->d);
|
||||
@ -1079,8 +1071,7 @@ static __device__ void dequantize_block_q8_0(const void * __restrict__ vx, dst_t
|
||||
}
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
extern "C" __global__ void dequantize_block_q8_K(const void * __restrict__ vx, float * __restrict__ yy) {
|
||||
const block_q8_K * x = (const block_q8_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
@ -1092,7 +1083,7 @@ static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t
|
||||
const int ir = tid%8;
|
||||
const int n = 8;
|
||||
|
||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||
float * y = yy + i*QK_K + 64*il + n*ir;
|
||||
|
||||
const int8_t * q = x[i].qs + 64*il + n*ir;
|
||||
|
||||
@ -1107,43 +1098,14 @@ static __device__ void dequantize_block_q8_K(const void * __restrict__ vx, dst_t
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q5_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
extern "C" __global__ void dequantize_block_q5_0(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
return dequantize_block<QK5_0, QR5_0, dequantize_q5_0>(vx, yy, nb32);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static __device__ void dequantize_block_q5_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||
extern "C" __global__ void dequantize_block_q5_1(const void * __restrict__ vx, float * __restrict__ yy, int nb32) {
|
||||
return dequantize_block<QK5_1, QR5_1, dequantize_q5_1>(vx, yy, nb32);
|
||||
}
|
||||
|
||||
#define DEQUANTIZE_K(QNAME) \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y) { \
|
||||
dequantize_block_##QNAME(vx, y); \
|
||||
} \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y) { \
|
||||
dequantize_block_##QNAME(vx, y); \
|
||||
} \
|
||||
|
||||
#define DEQUANTIZE(QNAME) \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) { \
|
||||
dequantize_block_##QNAME(vx, y, k); \
|
||||
} \
|
||||
extern "C" __global__ void dequantize_block_##QNAME##_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) { \
|
||||
dequantize_block_##QNAME(vx, y, k); \
|
||||
} \
|
||||
|
||||
DEQUANTIZE_K(q2_K)
|
||||
DEQUANTIZE_K(q3_K)
|
||||
DEQUANTIZE_K(q4_K)
|
||||
DEQUANTIZE_K(q5_K)
|
||||
DEQUANTIZE_K(q6_K)
|
||||
DEQUANTIZE_K(q8_K)
|
||||
DEQUANTIZE(q4_0)
|
||||
DEQUANTIZE(q4_1)
|
||||
DEQUANTIZE(q5_0)
|
||||
DEQUANTIZE(q5_1)
|
||||
DEQUANTIZE(q8_0)
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __device__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||
@ -3010,330 +2972,6 @@ extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda4(
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 5
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda5(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<5, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 6
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda6(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<6, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 7
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda7(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<7, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
// batch size = 8
|
||||
extern "C" __global__ void mul_mat_vec_q4_0_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_1_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_0_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_1_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q8_0_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q2_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q3_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q4_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q5_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void mul_mat_vec_q6_K_q8_1_cuda8(
|
||||
const void * vx, const void * vy, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
mul_mat_vec_q<8, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
extern "C" __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
||||
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
|
@ -1,88 +0,0 @@
|
||||
// Adapted from https://github.com/ggerganov/llama.cpp/blob/master/ggml-cuda/argsort.cu
|
||||
#define SORT_ORDER_ASC 1
|
||||
#define SORT_ORDER_DESC 0
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
template<typename T>
|
||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||
T tmp = a;
|
||||
a = b;
|
||||
b = tmp;
|
||||
}
|
||||
|
||||
template<int order, typename T>
|
||||
static __device__ void k_argsort(const T * x, uint32_t * dst, const int ncols, int ncols_pad) {
|
||||
// bitonic sort
|
||||
int col = threadIdx.x;
|
||||
int row = blockIdx.y;
|
||||
|
||||
if (col >= ncols_pad) {
|
||||
return;
|
||||
}
|
||||
|
||||
const T * x_row = x + row * ncols;
|
||||
extern __shared__ int dst_row[];
|
||||
|
||||
// initialize indices
|
||||
dst_row[col] = col;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||
for (int j = k / 2; j > 0; j /= 2) {
|
||||
int ixj = col ^ j;
|
||||
if (ixj > col) {
|
||||
if ((col & k) == 0) {
|
||||
if (dst_row[col] >= ncols ||
|
||||
(dst_row[ixj] < ncols && (order == SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
} else {
|
||||
if (dst_row[ixj] >= ncols ||
|
||||
(dst_row[col] < ncols && (order == SORT_ORDER_ASC ?
|
||||
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||
) {
|
||||
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// copy the result to dst without the padding
|
||||
if (col < ncols) {
|
||||
dst[row * ncols + col] = dst_row[col];
|
||||
}
|
||||
}
|
||||
|
||||
#define ASORT_OP(TYPENAME, RUST_NAME) \
|
||||
extern "C" __global__ void asort_asc_##RUST_NAME( \
|
||||
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
|
||||
) { \
|
||||
k_argsort<SORT_ORDER_ASC>(x, dst, ncols, ncols_pad); \
|
||||
} \
|
||||
extern "C" __global__ void asort_desc_##RUST_NAME( \
|
||||
const TYPENAME * x, uint32_t * dst, const int ncols, int ncols_pad \
|
||||
) { \
|
||||
k_argsort<SORT_ORDER_DESC>(x, dst, ncols, ncols_pad); \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
ASORT_OP(__nv_bfloat16, bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
ASORT_OP(__half, f16)
|
||||
#endif
|
||||
|
||||
ASORT_OP(float, f32)
|
||||
ASORT_OP(double, f64)
|
||||
ASORT_OP(uint8_t, u8)
|
||||
ASORT_OP(uint32_t, u32)
|
||||
ASORT_OP(int64_t, i64)
|
@ -60,11 +60,6 @@ __device__ __forceinline__ T silu_fwd(T x) {
|
||||
return x / (static_cast<T>(1) + expg(-x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T sigmoid_fwd(T x) {
|
||||
return recipg(static_cast<T>(1) + expg(-x));
|
||||
}
|
||||
|
||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
@ -121,7 +116,6 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
|
||||
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -148,7 +142,6 @@ UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
||||
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||
UNARY_OP(__half, usign_f16, sign_(x))
|
||||
UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||
@ -200,5 +193,3 @@ UNARY_OP1(float, upowf_f32, powg(x, param))
|
||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||
UNARY_OP(float, usign_f32, sign_(x))
|
||||
UNARY_OP(double, usign_f64, sign_(x))
|
||||
UNARY_OP(float, usigmoid_f32, sigmoid_fwd(x))
|
||||
UNARY_OP(double, usigmoid_f64, sigmoid_fwd(x))
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.5.1"
|
||||
version = "0.5.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
@ -21,7 +21,6 @@ const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const SORT: &str = include_str!("sort.metal");
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum Source {
|
||||
@ -36,7 +35,6 @@ pub enum Source {
|
||||
Conv,
|
||||
Random,
|
||||
Quantized,
|
||||
Sort,
|
||||
}
|
||||
|
||||
pub mod copy2d {
|
||||
@ -76,30 +74,6 @@ macro_rules! ops{
|
||||
}
|
||||
}
|
||||
|
||||
pub mod contiguous_tiled {
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
pub mod $name {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled"));
|
||||
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled"));
|
||||
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled"));
|
||||
pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled"));
|
||||
pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled"));
|
||||
pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled"));
|
||||
}
|
||||
)+
|
||||
pub mod copy {
|
||||
use super::Kernel;
|
||||
pub const FLOAT: Kernel = Kernel("copy_f32_tiled");
|
||||
pub const HALF: Kernel = Kernel("copy_f16_tiled");
|
||||
pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled");
|
||||
pub const I64: Kernel = Kernel("copy_i64_tiled");
|
||||
pub const U32: Kernel = Kernel("copy_u32_tiled");
|
||||
pub const U8: Kernel = Kernel("copy_u8_tiled");
|
||||
}
|
||||
}
|
||||
|
||||
pub mod strided {
|
||||
pub struct Kernel(pub &'static str);
|
||||
$(
|
||||
@ -129,7 +103,7 @@ macro_rules! ops{
|
||||
pub mod unary {
|
||||
ops!(
|
||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||
tanh, recip, silu, sign, sigmoid
|
||||
tanh, recip, silu, sign
|
||||
);
|
||||
}
|
||||
pub mod binary {
|
||||
@ -199,7 +173,6 @@ impl Kernels {
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Sort => SORT,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
}
|
||||
}
|
||||
@ -294,6 +267,30 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, &input, output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_copy2d(
|
||||
device: &Device,
|
||||
@ -337,58 +334,6 @@ pub fn call_copy2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous_tiled(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous_tiled::Kernel,
|
||||
length: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let tile_size = 2;
|
||||
let tiles = (length + tile_size - 1) / tile_size;
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, &input, output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_contiguous(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: BufferOffset,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, &input, output));
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_strided(
|
||||
device: &Device,
|
||||
@ -402,13 +347,16 @@ pub fn call_unary_strided(
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
|
||||
|
||||
let width: usize = shape.iter().product();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
@ -462,10 +410,10 @@ pub fn call_binary_strided(
|
||||
let num_dims: usize = shape.len();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
let width: usize = shape.iter().product();
|
||||
let length: usize = shape.iter().product();
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let length: usize = shape.iter().product();
|
||||
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
@ -479,12 +427,14 @@ pub fn call_binary_strided(
|
||||
output
|
||||
)
|
||||
);
|
||||
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1749,7 +1699,7 @@ pub enum GgmlDType {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_quantized_matmul_mv_t(
|
||||
pub fn call_quantized_matmul_t(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
@ -1758,8 +1708,7 @@ pub fn call_quantized_matmul_mv_t(
|
||||
lhs: &Buffer,
|
||||
lhs_offset: usize,
|
||||
rhs: &Buffer,
|
||||
dst_offset: usize,
|
||||
dst: &Buffer,
|
||||
output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// Everything is in reverse
|
||||
let ne00 = k as i64;
|
||||
@ -1799,9 +1748,8 @@ pub fn call_quantized_matmul_mv_t(
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
// Fixing a bug in Metal for GGML
|
||||
// https://github.com/ggerganov/llama.cpp/blob/b8109bc0139f15a5b321909f47510b89dca47ffc/ggml-metal.m#L1576
|
||||
let nth0 = 2;
|
||||
let nth1 = 32;
|
||||
let nth0 = 4;
|
||||
let nth1 = 8;
|
||||
let align = 4;
|
||||
(nth0, nth1, align)
|
||||
}
|
||||
@ -1873,7 +1821,7 @@ pub fn call_quantized_matmul_mv_t(
|
||||
(
|
||||
rhs,
|
||||
(lhs, lhs_offset),
|
||||
(dst, dst_offset),
|
||||
output,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
@ -1892,9 +1840,10 @@ pub fn call_quantized_matmul_mv_t(
|
||||
r3
|
||||
)
|
||||
);
|
||||
encoder.set_threadgroup_memory_length(0, 8192);
|
||||
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
|
||||
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
|
||||
encoder.end_encoding();
|
||||
@ -2051,42 +2000,5 @@ pub fn call_conv_transpose2d(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_arg_sort(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: &'static str,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
ncols_pad: usize,
|
||||
src: BufferOffset,
|
||||
dst: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sort, name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (&src, dst, ncols as i64, ncols_pad as i64));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: 1,
|
||||
height: nrows as u64,
|
||||
depth: 1,
|
||||
};
|
||||
let thread_group_size = MTLSize {
|
||||
width: ncols_pad as u64,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.use_resource(src.buffer, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(dst, metal::MTLResourceUsage::Write);
|
||||
encoder.set_threadgroup_memory_length(0, (ncols_pad * 4).max(16) as u64);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
@ -1,4 +1,3 @@
|
||||
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
|
||||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
@ -1,97 +0,0 @@
|
||||
// Imported from https://github.com/ggerganov/llama.cpp/blob/master/ggml-metal.metal
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
|
||||
#define SORT_ASC 1
|
||||
#define SORT_DESC 0
|
||||
|
||||
// Bitonic argsort of one row of `x`.
//
// Each threadgroup handles the row selected by the y coordinate of its
// position in the grid; each thread owns one slot of the padded index
// permutation kept in threadgroup memory. Slots holding an index >= ncols are
// padding: they are never used to read from `x` and are always ordered after
// real entries. Only the first `ncols` sorted indices are written to `dst`.
template<int order, typename T>
METAL_FUNC void argsort(
        device const T * x,
        device uint32_t * dst,
        constant int64_t & ncols,
        constant int64_t & ncols_pad,
        threadgroup uint32_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]]) {
    int lane = tpitg[0];
    int row = tgpig[1];

    if (lane >= ncols_pad) {
        return;
    }

    device const T * row_values = x + row * ncols;
    threadgroup uint32_t * perm = shared_values;

    // Start from the identity permutation.
    perm[lane] = lane;

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            int partner = lane ^ j;
            // The lower lane of each pair performs the compare-exchange.
            if (partner > lane) {
                const uint32_t mine = perm[lane];
                const uint32_t theirs = perm[partner];
                bool exchange;
                if ((lane & k) == 0) {
                    if (mine >= ncols) {
                        exchange = true;
                    } else if (theirs >= ncols) {
                        exchange = false;
                    } else {
                        exchange = order == SORT_ASC
                            ? row_values[mine] > row_values[theirs]
                            : row_values[mine] < row_values[theirs];
                    }
                } else {
                    if (theirs >= ncols) {
                        exchange = true;
                    } else if (mine >= ncols) {
                        exchange = false;
                    } else {
                        exchange = order == SORT_ASC
                            ? row_values[mine] < row_values[theirs]
                            : row_values[mine] > row_values[theirs];
                    }
                }
                if (exchange) {
                    perm[lane] = theirs;
                    perm[partner] = mine;
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }

    // Write the sorted indices back, dropping the padding slots.
    if (lane < ncols) {
        dst[row * ncols + lane] = perm[lane];
    }
}
|
||||
|
||||
// Emits a single kernel entry point called NAME that forwards to
// argsort<ORDER, T>.
#define ARGSORT_KERNEL(NAME, ORDER, T) \
kernel void NAME( \
    device const T * x, \
    device uint32_t * dst, \
    constant int64_t & ncols, \
    constant int64_t & ncols_pad, \
    threadgroup uint32_t * shared_values [[threadgroup(0)]], \
    uint3 tgpig[[threadgroup_position_in_grid]], \
    uint3 tpitg[[thread_position_in_threadgroup]] \
) { \
    argsort<ORDER, T>(x, dst, ncols, ncols_pad, shared_values, tgpig, tpitg); \
}

// Emits the ascending (asort_asc_<RUST_T>) and descending
// (asort_desc_<RUST_T>) kernels for element type T.
#define ARGSORT(T, RUST_T) \
ARGSORT_KERNEL(asort_asc_##RUST_T, SORT_ASC, T) \
ARGSORT_KERNEL(asort_desc_##RUST_T, SORT_DESC, T)

ARGSORT(float, f32)
ARGSORT(half, f16)
ARGSORT(uint8_t, u8)
ARGSORT(uint32_t, u32)

// 64-bit integer kernels are only emitted for Metal language version >= 2.2.
#if __METAL_VERSION__ >= 220
ARGSORT(int64_t, i64)
#endif
// bfloat kernels are only emitted when the toolchain defines __HAVE_BFLOAT__.
#if defined(__HAVE_BFLOAT__)
ARGSORT(bfloat, bf16)
#endif
|
@ -67,11 +67,6 @@ template <typename T> METAL_FUNC T relu(T in){
|
||||
// SiLU activation: in / (1 + exp(-in)).
template <typename T> METAL_FUNC T silu(T in){
    const T denom = static_cast<T>(1) + exp(-in);
    return in / denom;
}
|
||||
// Logistic sigmoid: applies `recip` to (1 + exp(-in)).
template <typename T> METAL_FUNC T sigmoid(T in) {
    const T denom = static_cast<T>(1) + exp(-in);
    return recip(denom);
}
|
||||
|
||||
#define TILE_SIZE 2
|
||||
|
||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||
kernel void FN_NAME( \
|
||||
@ -84,8 +79,8 @@ kernel void FN_NAME( \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = TYPENAME(FN(float(input[tid]))); \
|
||||
} \
|
||||
kernel void FN_NAME##_##strided( \
|
||||
}\
|
||||
kernel void FN_NAME_STRIDED( \
|
||||
constant size_t &dim, \
|
||||
constant size_t &num_dims, \
|
||||
constant size_t *dims, \
|
||||
@ -98,17 +93,6 @@ kernel void FN_NAME##_##strided( \
|
||||
return; \
|
||||
} \
|
||||
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
|
||||
} \
|
||||
kernel void FN_NAME##_##tiled( \
|
||||
constant size_t &dim, \
|
||||
device const TYPENAME *input, \
|
||||
device TYPENAME *output, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
for (uint i = 0; i < TILE_SIZE; i++) { \
|
||||
const uint idx = tid * TILE_SIZE + i; \
|
||||
output[idx] = TYPENAME(FN(float(input[idx]))); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define UNARY_OP(NAME) \
|
||||
@ -158,7 +142,6 @@ UNARY_OP(tanh)
|
||||
UNARY_OP(recip)
|
||||
UNARY_OP(relu)
|
||||
UNARY_OP(sign)
|
||||
UNARY_OP(sigmoid)
|
||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||
@ -189,7 +172,6 @@ BFLOAT_UNARY_OP(tanh)
|
||||
BFLOAT_UNARY_OP(recip)
|
||||
BFLOAT_UNARY_OP(relu)
|
||||
BFLOAT_UNARY_OP(sign)
|
||||
BFLOAT_UNARY_OP(sigmoid)
|
||||
|
||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||
|
||||
|
@ -28,7 +28,7 @@
|
||||
//! ```
|
||||
//!
|
||||
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
|
||||
use candle::{DType, Module, Result, Tensor, D};
|
||||
use candle::{DType, Result, Tensor, D};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct LayerNormConfig {
|
||||
@ -105,7 +105,7 @@ impl LayerNorm {
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LayerNorm {
|
||||
impl crate::Module for LayerNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
@ -162,20 +162,11 @@ impl RmsNorm {
|
||||
pub fn into_inner(self) -> LayerNorm {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Runs the forward pass through the wrapped `LayerNorm` rather than the fused `rms_norm` kernel, so it does not require a contiguous input.
|
||||
pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.0.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
impl crate::Module for RmsNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
if xs.is_contiguous() {
|
||||
crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
|
||||
} else {
|
||||
self.0.forward(xs)
|
||||
}
|
||||
self.0.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -43,193 +43,9 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
&xs[0].silu()? * &xs[1]
|
||||
}
|
||||
|
||||
struct Sigmoid;
|
||||
|
||||
impl candle::CustomOp1 for Sigmoid {
|
||||
fn name(&self) -> &'static str {
|
||||
"sigmoid"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
|
||||
fn fwd<T: num_traits::Float>(v: T) -> T {
|
||||
(v.neg().exp() + T::one()).recip()
|
||||
}
|
||||
|
||||
// FIXME: using `candle::map_dtype` causes compilation errors.
|
||||
let storage = match storage {
|
||||
CpuStorage::BF16(slice) => {
|
||||
CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
CpuStorage::F16(slice) => {
|
||||
CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
CpuStorage::F32(slice) => {
|
||||
CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
CpuStorage::F64(slice) => {
|
||||
CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
_ => Err(candle::Error::UnsupportedDTypeForOp(
|
||||
storage.dtype(),
|
||||
self.name(),
|
||||
))?,
|
||||
};
|
||||
Ok((storage, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
storage: &candle::CudaStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
};
|
||||
use candle::cuda_backend::SlicePtrOrNull;
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
||||
struct S;
|
||||
impl Map1 for S {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
|
||||
let params = (el_count, dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
let dev = storage.device();
|
||||
let slice = S.map(&storage.slice, dev, layout)?;
|
||||
let dst = candle::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &candle::MetalStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::MetalError;
|
||||
let device = storage.device();
|
||||
let dtype = storage.dtype();
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label("sigmoid");
|
||||
let src = candle_metal_kernels::BufferOffset {
|
||||
buffer: storage.buffer(),
|
||||
offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
};
|
||||
|
||||
match (el_count % 2, dtype, layout.is_contiguous()) {
|
||||
(0, DType::BF16 | DType::F16, true) => {
|
||||
use candle_metal_kernels::unary::contiguous_tiled;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous_tiled::sigmoid::HALF,
|
||||
DType::F32 => contiguous_tiled::sigmoid::FLOAT,
|
||||
DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
|
||||
dtype => {
|
||||
candle::bail!(
|
||||
"Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
|
||||
)
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous_tiled(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, true) => {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous::sigmoid::HALF,
|
||||
DType::F32 => contiguous::sigmoid::FLOAT,
|
||||
DType::BF16 => contiguous::sigmoid::BFLOAT,
|
||||
dtype => {
|
||||
candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, false) => {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => strided::sigmoid::HALF,
|
||||
DType::F32 => strided::sigmoid::FLOAT,
|
||||
DType::BF16 => strided::sigmoid::BFLOAT,
|
||||
dtype => {
|
||||
candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
src,
|
||||
layout.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
}
|
||||
|
||||
let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
|
||||
Ok((new_storage, layout.shape().clone()))
|
||||
}
|
||||
|
||||
fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||
// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
|
||||
let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
|
||||
Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply_op1(Sigmoid)
|
||||
// TODO: Should we have a specialized op for this?
|
||||
(xs.neg()?.exp()? + 1.0)?.recip()
|
||||
}
|
||||
|
||||
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
@ -254,7 +70,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
|
||||
let rand = Tensor::rand(0f32, 1f32, xs.shape(), xs.device())?;
|
||||
let scale = 1.0 / (1.0 - drop_p as f64);
|
||||
let drop_p = Tensor::new(drop_p, xs.device())?.broadcast_as(xs.shape())?;
|
||||
let mask = (rand.ge(&drop_p)?.to_dtype(xs.dtype())? * scale)?;
|
||||
let mask = (rand.ge(&drop_p)? * scale)?.to_dtype(xs.dtype())?;
|
||||
xs * mask
|
||||
}
|
||||
|
||||
|
@ -264,7 +264,7 @@ impl SimpleBackend for VarMap {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SafeTensorWithRouting<'a> {
|
||||
struct SafeTensorWithRouting<'a> {
|
||||
routing: HashMap<String, usize>,
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{test_utils, DType, Device, Tensor};
|
||||
use candle_nn::{batch_norm, BatchNorm, BatchNormConfig, VarBuilder, VarMap};
|
||||
use candle_nn::BatchNorm;
|
||||
|
||||
/* The test below has been generated using the following PyTorch code:
|
||||
import torch
|
||||
@ -20,7 +20,7 @@ print(m.running_mean)
|
||||
print(m.running_var)
|
||||
*/
|
||||
#[test]
|
||||
fn batch_norm_test() -> Result<()> {
|
||||
fn batch_norm() -> Result<()> {
|
||||
let running_mean = Tensor::zeros(5, DType::F32, &Device::Cpu)?;
|
||||
let running_var = Tensor::ones(5, DType::F32, &Device::Cpu)?;
|
||||
let bn = BatchNorm::new_no_bias(5, running_mean.clone(), running_var.clone(), 1e-8)?;
|
||||
@ -84,45 +84,3 @@ fn batch_norm_test() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that a batch norm layer built from a VarMap can be trained: the
// running mean must change after a training step, and the "running_mean"
// entry of the VarMap must follow it.
#[test]
fn train_batch_norm() -> Result<()> {
    let rounded = |t: &Tensor| test_utils::to_vec1_round(t, 4);

    let vm = VarMap::new();
    let vb = VarBuilder::from_varmap(&vm, DType::F32, &Device::Cpu);
    let bn = batch_norm(1, BatchNormConfig::default(), vb)?;

    // Detached copy of the initial mean, used later to detect the update.
    let original_mean = bn.running_mean().detach().copy()?;
    let var_map_mean = vm
        .data()
        .lock()
        .unwrap()
        .get("running_mean")
        .unwrap()
        .clone();

    // Before training the VarMap entry and the layer's running mean agree.
    assert_eq!(
        rounded(bn.running_mean())?,
        rounded(var_map_mean.as_tensor())?,
    );

    // Train with something guaranteed to differ from the running mean.
    let ones = original_mean.ones_like()?;
    let shifted = original_mean.add(&ones)?;
    let mean_plus_one = shifted.reshape((1, 1))?;
    bn.forward_train(&mean_plus_one)?;

    // The running mean has moved away from its initial value...
    assert_ne!(rounded(bn.running_mean())?, rounded(&original_mean)?);

    // ...and the VarMap entry reflects the new value.
    assert_eq!(
        rounded(bn.running_mean())?,
        rounded(var_map_mean.as_tensor())?,
    );
    Ok(())
}
|
||||
|
@ -170,19 +170,8 @@ fn rope_thd(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sigmoid(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let s1 = candle_nn::ops::sigmoid(&tensor)?;
|
||||
let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
|
||||
let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
|
||||
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
|
||||
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
|
||||
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
||||
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
||||
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-onnx"
|
||||
version = "0.5.1"
|
||||
version = "0.5.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "ONNX support for Candle"
|
||||
@ -10,8 +10,8 @@ categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.5.1" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.5.1" }
|
||||
candle = { path = "../candle-core", package = "candle-core", version = "0.5.0" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.5.0" }
|
||||
prost = "0.12.1"
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -23,11 +23,6 @@ trait Attr {
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
|
||||
}
|
||||
|
||||
trait AttrOwned: Sized {
|
||||
const TYPE: AttributeType;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self>;
|
||||
}
|
||||
|
||||
impl Attr for i64 {
|
||||
const TYPE: AttributeType = AttributeType::Int;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
@ -56,50 +51,6 @@ impl Attr for str {
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Tensor {
|
||||
const TYPE: AttributeType = AttributeType::Tensor;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
let tensor_proto = match &attr.t {
|
||||
Some(value) => value,
|
||||
None => bail!(
|
||||
"attribute {} was of type TENSOR, but no tensor was found",
|
||||
attr.name
|
||||
),
|
||||
};
|
||||
|
||||
let data_type = match DataType::try_from(tensor_proto.data_type) {
|
||||
Ok(value) => value,
|
||||
Err(_) => bail!(
|
||||
"attribute {} of type TENSOR was an invalid data_type number {}",
|
||||
attr.name,
|
||||
tensor_proto.data_type
|
||||
),
|
||||
};
|
||||
|
||||
let dtype = match dtype(data_type) {
|
||||
Some(value) => value,
|
||||
None => bail!(
|
||||
"attribute {} of type TENSOR has an unsupported data_type {}",
|
||||
attr.name,
|
||||
data_type.as_str_name()
|
||||
),
|
||||
};
|
||||
|
||||
let mut dims = Vec::with_capacity(tensor_proto.dims.len());
|
||||
for dim in &tensor_proto.dims {
|
||||
if dim < &0 {
|
||||
bail!(
|
||||
"attribute {} of type TENSOR has a negative dimension, which is unsupported",
|
||||
attr.name
|
||||
)
|
||||
}
|
||||
dims.push(*dim as usize)
|
||||
}
|
||||
|
||||
Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
@ -147,24 +98,6 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => Ok(None),
|
||||
Some(attr) => {
|
||||
if attr.r#type() != T::TYPE {
|
||||
bail!(
|
||||
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||
attr.r#type,
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
let val = T::get(attr)?;
|
||||
Ok(Some(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
@ -327,11 +260,6 @@ pub fn simple_eval(
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Exp" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let output = xs.exp()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
@ -530,17 +458,14 @@ pub fn simple_eval(
|
||||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
|
||||
"ConstantOfShape" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
|
||||
(),
|
||||
DType::F32,
|
||||
&Device::Cpu,
|
||||
)?);
|
||||
|
||||
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
|
||||
.broadcast_mul(&value)?;
|
||||
let dims = get(&node.input[0])?;
|
||||
let shape = dims
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|v| v as usize)
|
||||
.collect::<Vec<_>>();
|
||||
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Unsqueeze" => {
|
||||
@ -627,82 +552,6 @@ pub fn simple_eval(
|
||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||
values.insert(node.output[0].clone(), dims);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
|
||||
"Sqrt" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let output = xs.sqrt()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
|
||||
"Range" => {
|
||||
let start = get(&node.input[0])?;
|
||||
let limit = get(&node.input[1])?;
|
||||
let delta = get(&node.input[2])?;
|
||||
|
||||
macro_rules! arange_step {
|
||||
($t: ty) => {
|
||||
Tensor::arange_step(
|
||||
start.to_vec0::<$t>()?,
|
||||
limit.to_vec0::<$t>()?,
|
||||
delta.to_vec0::<$t>()?,
|
||||
&Device::Cpu,
|
||||
)?
|
||||
};
|
||||
}
|
||||
|
||||
let output = match start.dtype() {
|
||||
DType::U8 => arange_step!(u8),
|
||||
DType::U32 => arange_step!(u32),
|
||||
DType::I64 => arange_step!(i64),
|
||||
DType::BF16 => arange_step!(f32),
|
||||
DType::F16 => arange_step!(f32),
|
||||
DType::F32 => arange_step!(f32),
|
||||
DType::F64 => arange_step!(f64),
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
|
||||
"Greater" => {
|
||||
let a = get(&node.input[0])?;
|
||||
let b = get(&node.input[1])?;
|
||||
|
||||
let output = a.broadcast_gt(b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
|
||||
"Less" => {
|
||||
let a = get(&node.input[0])?;
|
||||
let b = get(&node.input[1])?;
|
||||
|
||||
let output = a.broadcast_lt(b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
|
||||
"Log" => {
|
||||
let a = get(&node.input[0])?;
|
||||
|
||||
let output = a.log()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
|
||||
"Min" => {
|
||||
let mut output = get(&node.input[0])?.clone();
|
||||
for input in node.input.iter() {
|
||||
let input = get(input)?;
|
||||
output = output.broadcast_minimum(input)?
|
||||
}
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
|
||||
"Where" => {
|
||||
let cond = get(&node.input[0])?;
|
||||
let a = get(&node.input[1])?;
|
||||
let b = get(&node.input[2])?;
|
||||
let output = cond.where_cond(a, b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
@ -971,46 +820,6 @@ pub fn simple_eval(
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"RandomUniform" => {
|
||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||
// type by
|
||||
// default
|
||||
let dtype = match DataType::try_from(dt as i32) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(DType::U8 | DType::U32 | DType::I64) => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
};
|
||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
||||
if seed.is_some() {
|
||||
bail!("seed for RandomUniform is currently not supported")
|
||||
};
|
||||
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
||||
.iter()
|
||||
.map(|x| *x as usize)
|
||||
.collect();
|
||||
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
@ -4,16 +4,12 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||
use candle_onnx::onnx;
|
||||
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle::{Device, NdArray, Result, Tensor};
|
||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use std::collections::HashMap;
|
||||
|
||||
const INPUT_X: &str = "x";
|
||||
const INPUT_Y: &str = "y";
|
||||
const INPUT_A: &str = "a";
|
||||
const OUTPUT_Z: &str = "z";
|
||||
|
||||
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
||||
@ -231,52 +227,6 @@ fn test_div_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Exp"
|
||||
#[test]
fn test_exp_operation() -> Result<()> {
    // Single-node graph: z = Exp(x).
    let exp_node = NodeProto {
        op_type: "Exp".to_string(),
        domain: "".to_string(),
        attribute: vec![],
        input: vec![INPUT_X.to_string()],
        output: vec![OUTPUT_Z.to_string()],
        name: "".to_string(),
        doc_string: "".to_string(),
    };
    let graph_output = ValueInfoProto {
        name: OUTPUT_Z.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    };
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![exp_node],
        name: "".to_string(),
        initializer: vec![],
        input: vec![],
        output: vec![graph_output],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let x = Tensor::from_vec(vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32], &[2, 2], &Device::Cpu)?;
    let inputs: HashMap<String, Tensor> = HashMap::from([(INPUT_X.to_string(), x)]);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let rows = z.to_vec2::<f32>()?;

    // exp(-1), exp(0) on the first row; exp(1), exp(2) on the second.
    assert_eq!(rows[0][0], 0.36787944f32);
    assert_eq!(rows[0][1], 1.0f32);
    assert_eq!(rows[1], vec![std::f32::consts::E, 7.38905609f32]);

    Ok(())
}
|
||||
|
||||
// "Equal"
|
||||
#[test]
|
||||
fn test_equal_operation() -> Result<()> {
|
||||
@ -870,137 +820,7 @@ fn test_flatten_operation() -> Result<()> {
|
||||
// #[test]
|
||||
|
||||
// "ConstantOfShape"
|
||||
#[test]
fn test_constant_of_shape() -> Result<()> {
    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
    // NOTE(review): the linked ONNX example produces an output of shape
    // [4, 3, 2]; the expectation below has shape [3] - confirm the intended
    // semantics of the evaluator.
    test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
    test(&[0.], Some(0i64), &[0i64])?;

    // "value" defaults to 0 f32
    test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;

    /// Wraps a scalar tensor into a TENSOR attribute named "value", with the
    /// scalar stored as little-endian bytes in `raw_data`.
    fn value_attribute(scalar: &Tensor) -> Result<AttributeProto> {
        let (raw_data, data_type) = match scalar.dtype() {
            DType::U8 => (
                scalar.to_vec0::<u8>()?.to_le_bytes().to_vec(),
                DataType::Uint8,
            ),
            DType::U32 => (
                scalar.to_vec0::<u32>()?.to_le_bytes().to_vec(),
                DataType::Uint32,
            ),
            DType::I64 => (
                scalar.to_vec0::<i64>()?.to_le_bytes().to_vec(),
                DataType::Int64,
            ),
            DType::F32 => (
                scalar.to_vec0::<f32>()?.to_le_bytes().to_vec(),
                DataType::Float,
            ),
            DType::F64 => (
                scalar.to_vec0::<f64>()?.to_le_bytes().to_vec(),
                DataType::Double,
            ),
            _ => panic!("unsupported DType in test"),
        };

        let tensor_proto = onnx::TensorProto {
            data_type: data_type.into(),
            dims: scalar.dims().iter().map(|v| *v as i64).collect(),
            raw_data,
            segment: None,
            float_data: vec![],
            int32_data: vec![],
            string_data: vec![],
            int64_data: vec![],
            name: "".to_string(),
            doc_string: "".to_string(),
            external_data: vec![],
            data_location: 0,
            double_data: vec![],
            uint64_data: vec![],
        };

        Ok(AttributeProto {
            name: "value".to_string(),
            ref_attr_name: "value".to_string(),
            i: 0,
            doc_string: "value".to_string(),
            r#type: AttributeType::Tensor.into(),
            f: 0.0,
            s: vec![],
            t: Some(tensor_proto),
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        })
    }

    /// Evaluates a single ConstantOfShape node on `input`, with an optional
    /// scalar "value" attribute, and compares the output with `expected`
    /// after converting both to f64.
    fn test(
        input: impl NdArray,
        value: Option<impl NdArray>,
        expected: impl NdArray,
    ) -> Result<()> {
        let attribute = match value {
            Some(value) => vec![value_attribute(&Tensor::new(value, &Device::Cpu)?)?],
            None => vec![],
        };

        let node = NodeProto {
            op_type: "ConstantOfShape".to_string(),
            domain: "".to_string(),
            attribute,
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        };
        let graph_output = ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        };
        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![node],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![graph_output],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let actual = eval
            .get(OUTPUT_Z)
            .expect("Output 'z' not found")
            .to_dtype(DType::F64)?;
        let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;

        match expected.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }
    Ok(())
}
|
||||
// #[test]
|
||||
|
||||
// "Unsqueeze"
|
||||
// #[test]
|
||||
@ -1819,596 +1639,3 @@ fn test_reduce_mean() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Sqrt"
|
||||
#[test]
fn test_sqrt() -> Result<()> {
    /// Runs a single-node `Sqrt` graph on `data` and checks the output against `expected`.
    fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
        // Fields that are not spelled out keep their protobuf defaults (empty / `None`).
        let sqrt_node = NodeProto {
            op_type: "Sqrt".to_string(),
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![sqrt_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let feeds = HashMap::from([(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?)]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        let actual = outputs.get(OUTPUT_Z).expect("Output 'z' not found");
        let wanted = Tensor::new(expected, &Device::Cpu)?;

        // Compare element-wise at whatever rank the expected tensor has.
        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155
    test(&[1., 4., 9.], &[1., 2., 3.])?;

    Ok(())
}
|
||||
|
||||
// "RandomUniform"
|
||||
#[test]
fn test_random_uniform() -> Result<()> {
    /// Evaluates a `RandomUniform` node producing a DOUBLE tensor of the given `shape` and
    /// checks that every sample lies within `[low, high]` (defaults 0.0 and 1.0 when the
    /// attribute is omitted) and that the samples are not all identical.
    fn test(shape: Vec<i64>, low: Option<f32>, high: Option<f32>) -> Result<()> {
        // Each attribute reuses its name for `ref_attr_name` and `doc_string`; all other
        // fields stay at their protobuf defaults unless overridden below.
        let named = |label: &str| AttributeProto {
            name: label.to_string(),
            ref_attr_name: label.to_string(),
            doc_string: label.to_string(),
            ..Default::default()
        };

        let mut attrs = vec![
            AttributeProto {
                r#type: 7, // INTS
                ints: shape,
                ..named("shape")
            },
            AttributeProto {
                i: 11,     // DOUBLE
                r#type: 2, // INT
                ..named("dtype")
            },
        ];
        // The bounds are only attached when the caller supplied them.
        if let Some(bound) = low {
            attrs.push(AttributeProto {
                r#type: 1, // FLOAT
                f: bound,
                ..named("low")
            });
        }
        if let Some(bound) = high {
            attrs.push(AttributeProto {
                r#type: 1, // FLOAT
                f: bound,
                ..named("high")
            });
        }

        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![NodeProto {
                op_type: "RandomUniform".to_string(),
                attribute: attrs,
                output: vec![OUTPUT_Z.to_string()],
                ..Default::default()
            }],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let outputs = candle_onnx::simple_eval(&model, HashMap::new())?;
        assert_eq!(outputs.len(), 1);
        let z = outputs.get(OUTPUT_Z).expect("Output 'z' not found");

        // Flatten once and derive both extrema from the same sample vector.
        let samples: Vec<f64> = z.flatten_all()?.to_vec1()?;
        let min = samples.iter().copied().reduce(f64::min).unwrap();
        let max = samples.iter().copied().reduce(f64::max).unwrap();

        let lower: f64 = low.unwrap_or(0.0).into();
        let upper: f64 = high.unwrap_or(1.0).into();
        assert!(min >= lower);
        assert!(max <= upper);
        assert_ne!(min, max);
        Ok(())
    }

    test(vec![3, 2, 1, 4], None, None)?;
    test(vec![2, 2, 2, 2], Some(-10.0), None)?;
    test(vec![2, 2, 2, 2], None, Some(10.0))?;
    test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;

    Ok(())
}
|
||||
|
||||
// "Range"
|
||||
#[test]
fn test_range() -> Result<()> {
    /// Feeds `start`, `limit` and `delta` to a single `Range` node and compares the
    /// result with `expected` after casting both sides to f64.
    fn test(
        start: impl NdArray,
        limit: impl NdArray,
        delta: impl NdArray,
        expected: impl NdArray,
    ) -> Result<()> {
        // Unlisted fields keep their protobuf defaults (empty / `None`).
        let range_node = NodeProto {
            op_type: "Range".to_string(),
            input: vec![
                INPUT_X.to_string(),
                INPUT_Y.to_string(),
                INPUT_A.to_string(),
            ],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![range_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let feeds = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?),
            (INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?),
        ]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        // Cast to f64 so integer and float cases share the same comparison code.
        let actual = outputs
            .get(OUTPUT_Z)
            .expect("Output 'z' not found")
            .to_dtype(DType::F64)?;
        let wanted = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;

        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
    test(1., 5., 2., &[1., 3.])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
    test(10i64, 6i64, -3i64, &[10i64, 7i64])?;

    Ok(())
}
|
||||
|
||||
// "Greater"
|
||||
#[test]
fn test_greater() -> Result<()> {
    /// Evaluates `Greater(a, b)` and compares the result with `expected` after casting
    /// both sides to f64.
    fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
        // Unlisted fields keep their protobuf defaults (empty / `None`).
        let greater_node = NodeProto {
            op_type: "Greater".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![greater_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let feeds = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?),
        ]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        let actual = outputs
            .get(OUTPUT_Z)
            .expect("Output 'z' not found")
            .to_dtype(DType::F64)?;
        let wanted = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;

        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
    test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
    test(&[1., 2., 3.], 2., &[0u8, 0, 1])?;

    Ok(())
}
|
||||
|
||||
// "Less"
|
||||
#[test]
fn test_less() -> Result<()> {
    /// Evaluates `Less(a, b)` and compares the result with `expected` after casting
    /// both sides to f64.
    fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
        // Unlisted fields keep their protobuf defaults (empty / `None`).
        let less_node = NodeProto {
            op_type: "Less".to_string(),
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![less_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let feeds = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?),
        ]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        let actual = outputs
            .get(OUTPUT_Z)
            .expect("Output 'z' not found")
            .to_dtype(DType::F64)?;
        let wanted = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;

        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
    test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
    test(&[1., 2., 3.], 2., &[1u8, 0, 0])?;

    Ok(())
}
|
||||
|
||||
// "Log"
|
||||
#[test]
fn test_log() -> Result<()> {
    /// Runs a single-node `Log` graph on `data` and checks the output against `expected`.
    fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
        // Unlisted fields keep their protobuf defaults (empty / `None`).
        let log_node = NodeProto {
            op_type: "Log".to_string(),
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![log_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let feeds = HashMap::from([(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?)]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        let actual = outputs.get(OUTPUT_Z).expect("Output 'z' not found");
        let wanted = Tensor::new(expected, &Device::Cpu)?;

        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82
    test(&[1., 10.], &[0., std::f64::consts::LN_10])?;

    Ok(())
}
|
||||
|
||||
// "Min"
|
||||
#[test]
fn test_min() -> Result<()> {
    /// Evaluates a three-input `Min` node and checks the output against `expected`.
    fn test(
        a: impl NdArray,
        b: impl NdArray,
        c: impl NdArray,
        expected: impl NdArray,
    ) -> Result<()> {
        // Unlisted fields keep their protobuf defaults (empty / `None`).
        let min_node = NodeProto {
            op_type: "Min".to_string(),
            input: vec![
                INPUT_X.to_string(),
                INPUT_Y.to_string(),
                INPUT_A.to_string(),
            ],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![min_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        let feeds = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?),
            (INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?),
        ]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        let actual = outputs.get(OUTPUT_Z).expect("Output 'z' not found");
        let wanted = Tensor::new(expected, &Device::Cpu)?;

        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94
    test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?;

    Ok(())
}
|
||||
|
||||
// "Where"
|
||||
#[test]
fn test_where() -> Result<()> {
    /// Evaluates `Where(condition, x, y)` and compares the result with `expected` after
    /// casting both sides to f64.
    fn test(
        condition: impl NdArray,
        x: impl NdArray,
        y: impl NdArray,
        expected: impl NdArray,
    ) -> Result<()> {
        // Unlisted fields keep their protobuf defaults (empty / `None`).
        let where_node = NodeProto {
            op_type: "Where".to_string(),
            input: vec![
                INPUT_X.to_string(),
                INPUT_Y.to_string(),
                INPUT_A.to_string(),
            ],
            output: vec![OUTPUT_Z.to_string()],
            ..Default::default()
        };
        let model = create_model_proto_with_graph(Some(GraphProto {
            node: vec![where_node],
            output: vec![ValueInfoProto {
                name: OUTPUT_Z.to_string(),
                ..Default::default()
            }],
            ..Default::default()
        }));

        // The node's inputs are, in order: condition, x, y.
        let feeds = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?),
            (INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?),
        ]);

        let outputs = candle_onnx::simple_eval(&model, feeds)?;
        assert_eq!(outputs.len(), 1);

        let actual = outputs
            .get(OUTPUT_Z)
            .expect("Output 'z' not found")
            .to_dtype(DType::F64)?;
        let wanted = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;

        match wanted.dims().len() {
            0 => assert_eq!(actual.to_vec0::<f64>()?, wanted.to_vec0::<f64>()?),
            1 => assert_eq!(actual.to_vec1::<f64>()?, wanted.to_vec1::<f64>()?),
            2 => assert_eq!(actual.to_vec2::<f64>()?, wanted.to_vec2::<f64>()?),
            3 => assert_eq!(actual.to_vec3::<f64>()?, wanted.to_vec3::<f64>()?),
            _ => unreachable!(),
        };

        Ok(())
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
    test(
        &[[1u8, 0], [1, 1]],
        &[[1i64, 2], [3, 4]],
        &[[9i64, 8], [7, 6]],
        &[[1i64, 8], [3, 4]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
    test(
        &[[1u8, 0], [1, 1]],
        &[[1., 2.], [3., 4.]],
        &[[9., 8.], [7., 6.]],
        &[[1., 8.], [3., 4.]],
    )?;

    Ok(())
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle::{DType, Error, IndexOp, Result, Tensor, D};
|
||||
use candle::{DType, Error, Result, Tensor};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
@ -73,15 +73,17 @@ impl LogitsProcessor {
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
fn sample_topk(&mut self, logits: &Tensor, top_k: usize, temperature: f64) -> Result<u32> {
|
||||
let arg_sort = logits.arg_sort_last_dim(false)?;
|
||||
let top_k_indices = arg_sort.narrow(candle::D::Minus1, 0, top_k)?;
|
||||
let top_k_logits = logits.gather(&top_k_indices, D::Minus1)?;
|
||||
let top_k_logits = (&top_k_logits / temperature)?;
|
||||
let top_k_prs = candle_nn::ops::softmax_last_dim(&top_k_logits)?;
|
||||
let top_k_prs = top_k_prs.to_vec1()?;
|
||||
let index = self.sample_multinomial(&top_k_prs)?;
|
||||
Ok(top_k_indices.i(index as usize)?.to_vec0::<u32>()?)
|
||||
fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
|
||||
if top_k >= prs.len() {
|
||||
self.sample_multinomial(prs)
|
||||
} else {
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
let (indices, _, _) =
|
||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||
let index = self.sample_multinomial(&prs)?;
|
||||
Ok(indices[index as usize] as u32)
|
||||
}
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
@ -135,12 +137,8 @@ impl LogitsProcessor {
|
||||
}
|
||||
}
|
||||
Sampling::TopK { k, temperature } => {
|
||||
if *k >= logits.dim(D::Minus1)? {
|
||||
let prs = prs(*temperature)?;
|
||||
self.sample_multinomial(&prs)?
|
||||
} else {
|
||||
self.sample_topk(&logits, *k, *temperature)?
|
||||
}
|
||||
let mut prs = prs(*temperature)?;
|
||||
self.sample_topk(&mut prs, *k)?
|
||||
}
|
||||
Sampling::TopKThenTopP { k, p, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
|
@ -227,9 +227,8 @@ impl Attention {
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
|
||||
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
|
||||
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
|
@ -16,14 +16,6 @@ pub struct LlamaConfig {
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
pub fn num_key_value_heads(&self) -> usize {
|
||||
self.num_key_value_heads.unwrap_or(self.num_attention_heads)
|
||||
}
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
@ -38,12 +30,10 @@ impl LlamaConfig {
|
||||
vocab_size: self.vocab_size,
|
||||
num_hidden_layers: self.num_hidden_layers,
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads(),
|
||||
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
bos_token_id: self.bos_token_id,
|
||||
eos_token_id: self.eos_token_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -59,8 +49,6 @@ pub struct Config {
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
pub bos_token_id: Option<u32>,
|
||||
pub eos_token_id: Option<u32>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -75,8 +63,6 @@ impl Config {
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-6,
|
||||
rope_theta: 10_000.0,
|
||||
bos_token_id: None,
|
||||
eos_token_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,8 +77,6 @@ impl Config {
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10_000.0,
|
||||
bos_token_id: None,
|
||||
eos_token_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -122,6 +106,7 @@ impl Cache {
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
Ok(Self {
|
||||
@ -181,10 +166,16 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
|
||||
impl CausalSelfAttention {
|
||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (_b_sz, _, seq_len, _hidden_size) = x.dims4()?;
|
||||
let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
|
||||
let cos = cache.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = cache.sin.narrow(0, index_pos, seq_len)?;
|
||||
candle_nn::rotary_emb::rope(x, &cos, &sin)
|
||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||
Ok(rope)
|
||||
}
|
||||
|
||||
fn forward(
|
||||
@ -202,12 +193,10 @@ impl CausalSelfAttention {
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
.transpose(1, 2)?;
|
||||
let mut v = v
|
||||
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
@ -26,10 +26,8 @@ pub mod mixtral;
|
||||
pub mod mobileone;
|
||||
pub mod moondream;
|
||||
pub mod mpt;
|
||||
pub mod olmo;
|
||||
pub mod persimmon;
|
||||
pub mod phi;
|
||||
pub mod phi3;
|
||||
pub mod quantized_blip;
|
||||
pub mod quantized_blip_text;
|
||||
pub mod quantized_llama;
|
||||
@ -39,8 +37,6 @@ pub mod quantized_mistral;
|
||||
pub mod quantized_mixformer;
|
||||
pub mod quantized_moondream;
|
||||
pub mod quantized_mpt;
|
||||
pub mod quantized_phi;
|
||||
pub mod quantized_phi3;
|
||||
pub mod quantized_recurrent_gemma;
|
||||
pub mod quantized_rwkv_v5;
|
||||
pub mod quantized_rwkv_v6;
|
||||
|
@ -302,7 +302,6 @@ impl Module for VisionEncoder {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Model {
|
||||
pub text_model: PhiModel,
|
||||
pub vision_encoder: VisionEncoder,
|
||||
|
@ -1,337 +0,0 @@
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{linear_b, linear_no_bias, Activation, LayerNorm, Linear, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub attention_bias: bool,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub hidden_act: candle_nn::Activation,
|
||||
pub max_position_embeddings: usize,
|
||||
pub rope_theta: f64,
|
||||
pub tie_word_embeddings: bool,
|
||||
pub clip_qkv: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct MLP {
|
||||
gate_proj: Linear,
|
||||
up_proj: Linear,
|
||||
down_proj: Linear,
|
||||
act_fn: Activation,
|
||||
}
|
||||
|
||||
impl MLP {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let intermediate_sz = cfg.intermediate_size;
|
||||
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_proj,
|
||||
up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_act,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MLP {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||
let rhs = xs.apply(&self.up_proj)?;
|
||||
(lhs * rhs)?.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
head_dim: usize,
|
||||
hidden_size: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
qkv_clip: Option<f64>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let num_kv_groups = num_heads / num_kv_heads;
|
||||
let head_dim = hidden_sz / num_heads;
|
||||
let b = cfg.attention_bias;
|
||||
let qkv_clip = cfg.clip_qkv;
|
||||
let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups,
|
||||
head_dim,
|
||||
hidden_size: hidden_sz,
|
||||
rotary_emb,
|
||||
qkv_clip,
|
||||
kv_cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let query_states = self.q_proj.forward(xs)?;
|
||||
let key_states = self.k_proj.forward(xs)?;
|
||||
let value_states = self.v_proj.forward(xs)?;
|
||||
|
||||
let (query_states, key_states, value_states) = match &self.qkv_clip {
|
||||
None => (query_states, key_states, value_states),
|
||||
Some(qkv_clip) => {
|
||||
let query_states = Tensor::clamp(&query_states, -qkv_clip, *qkv_clip)?;
|
||||
let key_states = Tensor::clamp(&key_states, -qkv_clip, *qkv_clip)?;
|
||||
let value_states = Tensor::clamp(&value_states, -qkv_clip, *qkv_clip)?;
|
||||
(query_states, key_states, value_states)
|
||||
}
|
||||
};
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&value_states)?
|
||||
};
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, self.hidden_size))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
/// One decoder block: layer norm, self-attention with a residual connection,
/// then layer norm and MLP with a second residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
    /// Attention sub-layer (owns the per-layer key/value cache).
    self_attn: Attention,
    /// Feed-forward sub-layer.
    mlp: MLP,
    /// Normalization applied before attention.
    input_layernorm: LayerNorm,
    /// Normalization applied before the MLP.
    post_attention_layernorm: LayerNorm,
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let ln_weight = Tensor::ones(cfg.hidden_size, vb.dtype(), vb.device())?;
|
||||
let input_layernorm = LayerNorm::new_no_bias(ln_weight.clone(), 1e-5);
|
||||
let post_attention_layernorm = LayerNorm::new_no_bias(ln_weight.clone(), 1e-5);
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
/// Decoder-only language model: token embedding, a stack of decoder layers,
/// a final layer norm and an output projection to the vocabulary.
#[derive(Debug, Clone)]
pub struct Model {
    /// Token embedding table.
    embed_tokens: candle_nn::Embedding,
    /// Decoder blocks, applied in order.
    layers: Vec<DecoderLayer>,
    /// Final normalization applied before `lm_head`.
    norm: LayerNorm,
    /// Output projection; shares the embedding weights when the config ties them.
    lm_head: Linear,
    /// Device on which attention masks are created.
    device: Device,
    /// Dtype the attention mask is converted to before use.
    dtype: DType,
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let ln_weight = Tensor::ones(cfg.hidden_size, vb.dtype(), vb.device())?;
|
||||
let norm = LayerNorm::new_no_bias(ln_weight, 1e-5);
|
||||
let lm_head = if cfg.tie_word_embeddings {
|
||||
Linear::new(embed_tokens.embeddings().clone(), None)
|
||||
} else {
|
||||
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||
};
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
// Sliding window mask?
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
@ -72,7 +72,7 @@ impl RotaryEmbedding {
|
||||
let (xs1, xs2) = (&xs12[0], &xs12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
|
||||
let rotate_half = Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)?;
|
||||
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
|
||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||
}
|
||||
|
@ -1,329 +0,0 @@
|
||||
// This implementation is based on:
|
||||
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
|
||||
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::VarBuilder;
|
||||
use std::sync::Arc;
|
||||
|
||||
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
|
||||
/// Model hyper-parameters, deserialized from the model's `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    /// Number of rows in the embedding table and output width of `lm_head`.
    pub vocab_size: usize,
    /// Activation applied to the gate half of the MLP.
    pub hidden_act: candle_nn::Activation,
    /// Width of the hidden states.
    pub hidden_size: usize,
    /// Width of each half of the fused gate/up MLP projection.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads.
    pub num_key_value_heads: usize,
    /// Epsilon used by every RMS norm in the model.
    pub rms_norm_eps: f64,
    /// Base used to derive the rotary inverse frequencies.
    pub rope_theta: f64,
    /// Deserialized but not read by the model code in this file.
    pub bos_token_id: Option<u32>,
    /// Deserialized but not read by the model code in this file.
    pub eos_token_id: Option<u32>,
    /// Deserialized but not read by the model code in this file.
    pub rope_scaling: Option<String>,
    /// Number of positions for which rotary sin/cos tables are precomputed.
    pub max_position_embeddings: usize,
}
|
||||
|
||||
impl Config {
|
||||
pub fn head_dim(&self) -> usize {
|
||||
self.hidden_size / self.num_attention_heads
|
||||
}
|
||||
}
|
||||
|
||||
/// Precomputed rotary-embedding tables, one row per position and one column
/// per inverse frequency.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// Sine of `position * inv_freq`.
    sin: Tensor,
    /// Cosine of `position * inv_freq`.
    cos: Tensor,
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.head_dim();
|
||||
let max_seq_len = cfg.max_position_embeddings;
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
||||
/// Self-attention block with a single fused query/key/value projection,
/// rotary position embeddings and a key/value cache.
#[derive(Debug, Clone)]
struct Attention {
    /// Fused projection: `hidden_size -> (num_heads + 2 * num_kv_heads) * head_dim`,
    /// laid out as queries, then keys, then values along the last dimension.
    qkv_proj: Linear,
    /// Output projection: `num_heads * head_dim -> hidden_size`.
    o_proj: Linear,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// `num_heads / num_kv_heads`; repetition factor handed to `repeat_kv`.
    num_kv_groups: usize,
    /// Per-head feature size.
    head_dim: usize,
    /// Rotary embedding tables shared between layers.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Keys and values accumulated from previous calls (post rotary embedding).
    kv_cache: Option<(Tensor, Tensor)>,
}
|
||||
|
||||
impl Attention {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
let head_dim = cfg.head_dim();
|
||||
let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
|
||||
let qkv_proj = linear(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?;
|
||||
let o_proj = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?;
|
||||
Ok(Self {
|
||||
qkv_proj,
|
||||
o_proj,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
num_kv_groups: num_heads / num_kv_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let (b_sz, q_len, _) = xs.dims3()?;
|
||||
|
||||
let qkv = self.qkv_proj.forward(xs)?;
|
||||
let query_pos = self.num_heads * self.head_dim;
|
||||
let query_states = qkv.narrow(D::Minus1, 0, query_pos)?;
|
||||
let key_states = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
|
||||
let value_states = qkv.narrow(
|
||||
D::Minus1,
|
||||
query_pos + self.num_kv_heads * self.head_dim,
|
||||
self.num_kv_heads * self.head_dim,
|
||||
)?;
|
||||
|
||||
let query_states = query_states
|
||||
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let key_states = key_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let value_states = value_states
|
||||
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let (query_states, key_states) =
|
||||
self.rotary_emb
|
||||
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||
|
||||
let (key_states, value_states) = match &self.kv_cache {
|
||||
None => (key_states, value_states),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||
(key_states, value_states)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||
|
||||
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
let attn_weights = match attention_mask {
|
||||
None => attn_weights,
|
||||
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||
};
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&value_states)?
|
||||
};
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, q_len, ()))?
|
||||
.apply(&self.o_proj)
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.kv_cache = None
|
||||
}
|
||||
}
|
||||
|
||||
/// Gated feed-forward block with a fused gate/up projection.
#[derive(Debug, Clone)]
struct Mlp {
    /// Fused projection: `hidden_size -> 2 * i_size` (gate half, then up half).
    gate_up_proj: Linear,
    /// Projection back to the hidden size: `i_size -> hidden_size`.
    down_proj: Linear,
    /// Activation applied to the gate half.
    act_fn: candle_nn::Activation,
    /// Width of each half of the fused projection (`intermediate_size`).
    i_size: usize,
}
|
||||
|
||||
impl Mlp {
|
||||
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let i_size = cfg.intermediate_size;
|
||||
let gate_up_proj = linear(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?;
|
||||
let down_proj = linear(i_size, hidden_size, vb.pp("down_proj"))?;
|
||||
Ok(Self {
|
||||
gate_up_proj,
|
||||
down_proj,
|
||||
act_fn: cfg.hidden_act,
|
||||
i_size,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let up_states = xs.apply(&self.gate_up_proj)?;
|
||||
let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
|
||||
let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
|
||||
let up_states = (up_states * gate.apply(&self.act_fn))?;
|
||||
up_states.apply(&self.down_proj)
|
||||
}
|
||||
}
|
||||
|
||||
/// One decoder block: RMS norm, self-attention with a residual connection,
/// then RMS norm and MLP with a second residual connection.
#[derive(Debug, Clone)]
struct DecoderLayer {
    /// Attention sub-layer (owns the per-layer key/value cache).
    self_attn: Attention,
    /// Feed-forward sub-layer.
    mlp: Mlp,
    /// Normalization applied before attention.
    input_layernorm: RmsNorm,
    /// Normalization applied before the MLP.
    post_attention_layernorm: RmsNorm,
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
let post_attention_layernorm = RmsNorm::new(
|
||||
cfg.hidden_size,
|
||||
cfg.rms_norm_eps,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
mlp,
|
||||
input_layernorm,
|
||||
post_attention_layernorm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
xs: &Tensor,
|
||||
attention_mask: Option<&Tensor>,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = self.input_layernorm.forward(xs)?;
|
||||
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||
let xs = (xs + residual)?;
|
||||
let residual = &xs;
|
||||
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||
residual + xs
|
||||
}
|
||||
|
||||
fn clear_kv_cache(&mut self) {
|
||||
self.self_attn.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
||||
/// Decoder-only language model: token embedding, a stack of decoder layers,
/// a final RMS norm and an output projection to the vocabulary.
#[derive(Debug, Clone)]
pub struct Model {
    /// Token embedding table.
    embed_tokens: candle_nn::Embedding,
    /// Decoder blocks, applied in order.
    layers: Vec<DecoderLayer>,
    /// Final normalization applied before `lm_head`.
    norm: RmsNorm,
    /// Output projection: `hidden_size -> vocab_size`.
    lm_head: Linear,
    /// Device on which attention masks are created.
    device: Device,
    /// Dtype the attention mask is converted to before use.
    dtype: DType,
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..tgt_len)
|
||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
for layer in self.layers.iter_mut() {
|
||||
layer.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::quantized_nn::RmsNorm;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||
@ -29,13 +28,13 @@ impl QMatMul {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
struct MlpSilu {
|
||||
feed_forward_w1: QMatMul,
|
||||
feed_forward_w2: QMatMul,
|
||||
feed_forward_w3: QMatMul,
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
impl Module for MlpSilu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let w1 = self.feed_forward_w1.forward(xs)?;
|
||||
let w3 = self.feed_forward_w3.forward(xs)?;
|
||||
@ -45,16 +44,31 @@ impl Module for Mlp {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum MlpOrMoe {
|
||||
Mlp(Mlp),
|
||||
struct MlpSimple {
|
||||
fc1: QMatMul,
|
||||
fc2: QMatMul,
|
||||
act: candle_nn::Activation,
|
||||
}
|
||||
|
||||
impl Module for MlpSimple {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs = self.fc1.forward(xs)?.apply(&self.act)?;
|
||||
self.fc2.forward(&xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum Mlp {
|
||||
Silu(MlpSilu),
|
||||
Simple(MlpSimple),
|
||||
MoE {
|
||||
n_expert_used: usize,
|
||||
feed_forward_gate_inp: QMatMul,
|
||||
experts: Vec<Mlp>,
|
||||
experts: Vec<MlpSilu>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Module for MlpOrMoe {
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::MoE {
|
||||
@ -119,20 +133,48 @@ impl Module for MlpOrMoe {
|
||||
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
|
||||
Ok(ys)
|
||||
}
|
||||
Self::Mlp(mlp) => mlp.forward(xs),
|
||||
Self::Silu(mlp) => mlp.forward(xs),
|
||||
Self::Simple(mlp) => mlp.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalization layer used by the quantized model; `forward` dispatches to
/// whichever variant is held.
#[derive(Debug, Clone)]
enum Norm {
    /// RMS normalization built from a quantized weight tensor.
    Rms(crate::quantized_nn::RmsNorm),
    /// Layer normalization with dequantized weight and bias.
    Layer(candle_nn::LayerNorm),
}
|
||||
|
||||
impl Module for Norm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
Self::Rms(m) => m.forward(xs),
|
||||
Self::Layer(m) => m.forward(xs),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rms_norm(q: QTensor, eps: f64) -> Result<Norm> {
|
||||
let rms = crate::quantized_nn::RmsNorm::from_qtensor(q, eps)?;
|
||||
Ok(Norm::Rms(rms))
|
||||
}
|
||||
|
||||
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<Norm> {
|
||||
let w = w.dequantize(&w.device())?;
|
||||
let b = b.dequantize(&b.device())?;
|
||||
let ln = candle_nn::LayerNorm::new(w, b, eps);
|
||||
Ok(Norm::Layer(ln))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
attention_wq: QMatMul,
|
||||
attention_wk: QMatMul,
|
||||
attention_wv: QMatMul,
|
||||
attention_wo: QMatMul,
|
||||
attention_norm: RmsNorm,
|
||||
mlp_or_moe: MlpOrMoe,
|
||||
ffn_norm: RmsNorm,
|
||||
attention_norm: Norm,
|
||||
mlp: Mlp,
|
||||
ffn_norm: Norm,
|
||||
n_head: usize,
|
||||
n_kv_head: usize,
|
||||
head_dim: usize,
|
||||
@ -230,7 +272,7 @@ impl LayerWeights {
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
layers: Vec<LayerWeights>,
|
||||
norm: RmsNorm,
|
||||
norm: Norm,
|
||||
output: QMatMul,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
@ -256,6 +298,99 @@ fn precomput_freqs_cis(
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
/// Model family of a GGUF file, derived from its `general.architecture`
/// metadata entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Architecture {
    /// Default: chosen when the entry is missing, unreadable, or anything other than `"phi2"`.
    Llama,
    /// Chosen when the entry equals `"phi2"`.
    Phi2,
}
|
||||
|
||||
/// Hyper-parameters extracted from GGUF metadata, normalized across the
/// supported architectures.
#[derive(Debug, Clone)]
struct MetadataConfig {
    /// Number of MoE experts (`llama.expert_count`, 0 when absent; fixed to 1 for Phi2).
    n_expert: usize,
    /// Number of experts used per token (`llama.expert_used_count`, 0 when absent; fixed to 1 for Phi2).
    n_expert_used: usize,
    /// `<arch>.attention.head_count`.
    head_count: usize,
    /// `<arch>.attention.head_count_kv`.
    head_count_kv: usize,
    /// `<arch>.block_count`; number of layers to load.
    block_count: usize,
    /// `<arch>.embedding_length`.
    embedding_length: usize,
    /// `<arch>.rope.dimension_count`.
    rope_dim: usize,
    /// Norm epsilon: `llama.attention.layer_norm_rms_epsilon` for Llama,
    /// `phi2.attention.layer_norm_epsilon` for Phi2.
    rms_norm_eps: f64,
    /// `llama.rope.freq_base`, defaulting to 10000; always 10000 for Phi2.
    rope_freq_base: f32,
    /// Which metadata key family the values above were read from.
    architecture: Architecture,
}
|
||||
|
||||
impl MetadataConfig {
|
||||
fn from_gguf(ct: &gguf_file::Content) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
let architecture = match md_get("general.architecture")
|
||||
.and_then(|v| v.to_string())
|
||||
.map(|v| v.as_str())
|
||||
{
|
||||
Ok("phi2") => Architecture::Phi2,
|
||||
Err(_) | Ok(_) => Architecture::Llama,
|
||||
};
|
||||
|
||||
let config = match architecture {
|
||||
Architecture::Phi2 => {
|
||||
let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
|
||||
let rms_norm_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
|
||||
Self {
|
||||
n_expert: 1,
|
||||
n_expert_used: 1,
|
||||
head_count,
|
||||
head_count_kv,
|
||||
block_count,
|
||||
embedding_length,
|
||||
rope_freq_base: 10_000.,
|
||||
rope_dim,
|
||||
rms_norm_eps,
|
||||
architecture,
|
||||
}
|
||||
}
|
||||
Architecture::Llama => {
|
||||
let n_expert = md_get("llama.expert_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let n_expert_used = md_get("llama.expert_used_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
|
||||
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
|
||||
let rms_norm_eps =
|
||||
md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
Self {
|
||||
n_expert,
|
||||
n_expert_used,
|
||||
head_count,
|
||||
head_count_kv,
|
||||
block_count,
|
||||
embedding_length,
|
||||
rope_freq_base,
|
||||
rope_dim,
|
||||
rms_norm_eps,
|
||||
architecture,
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
@ -263,7 +398,7 @@ impl ModelWeights {
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
||||
let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let norm = rms_norm(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let output = ct.remove("output.weight")?;
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
for layer_idx in 0..ct.hparams.n_layer {
|
||||
@ -272,11 +407,11 @@ impl ModelWeights {
|
||||
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
|
||||
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
|
||||
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
|
||||
let mlp_or_moe = {
|
||||
let mlp = {
|
||||
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
|
||||
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
|
||||
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
||||
MlpOrMoe::Mlp(Mlp {
|
||||
Mlp::Silu(MlpSilu {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
@ -292,9 +427,9 @@ impl ModelWeights {
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
|
||||
mlp_or_moe,
|
||||
ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
|
||||
attention_norm: rms_norm(attention_norm, 1e-5)?,
|
||||
mlp,
|
||||
ffn_norm: rms_norm(ffn_norm, 1e-5)?,
|
||||
n_head: ct.hparams.n_head as usize,
|
||||
n_kv_head: ct.hparams.n_head as usize / gqa,
|
||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||
@ -325,78 +460,71 @@ impl ModelWeights {
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
let cfg = MetadataConfig::from_gguf(&ct)?;
|
||||
|
||||
// Parameter extraction from metadata.
|
||||
let n_expert = md_get("llama.expert_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let n_expert_used = md_get("llama.expert_used_count")
|
||||
.and_then(|v| v.to_u32())
|
||||
.unwrap_or(0) as usize;
|
||||
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
|
||||
// Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
|
||||
let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(cfg.rope_dim, cfg.rope_freq_base, device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::from_qtensor(
|
||||
let norm = rms_norm(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
cfg.rms_norm_eps,
|
||||
)?;
|
||||
let output = ct.tensor(reader, "output.weight", device)?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let mut layers = Vec::with_capacity(cfg.block_count);
|
||||
for layer_idx in 0..cfg.block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||
let attention_wo =
|
||||
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||
let mlp_or_moe = if n_expert <= 1 {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
MlpOrMoe::Mlp(Mlp {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
})
|
||||
let mlp = if cfg.n_expert <= 1 {
|
||||
match cfg.architecture {
|
||||
Architecture::Llama => {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
Mlp::Silu(MlpSilu {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
})
|
||||
}
|
||||
Architecture::Phi2 => {
|
||||
let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||
let fc2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||
Mlp::Simple(MlpSimple {
|
||||
fc1: QMatMul::from_qtensor(fc1)?,
|
||||
fc2: QMatMul::from_qtensor(fc2)?,
|
||||
act: candle_nn::Activation::NewGelu,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let feed_forward_gate_inp =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
|
||||
let mut experts = Vec::with_capacity(n_expert);
|
||||
for i in 0..n_expert {
|
||||
let mut experts = Vec::with_capacity(cfg.n_expert);
|
||||
for i in 0..cfg.n_expert {
|
||||
let feed_forward_w1 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
|
||||
let feed_forward_w2 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
|
||||
let feed_forward_w3 =
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
|
||||
experts.push(Mlp {
|
||||
experts.push(MlpSilu {
|
||||
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
|
||||
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
|
||||
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
|
||||
})
|
||||
}
|
||||
MlpOrMoe::MoE {
|
||||
n_expert_used,
|
||||
Mlp::MoE {
|
||||
n_expert_used: cfg.n_expert_used,
|
||||
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
|
||||
experts,
|
||||
}
|
||||
@ -412,12 +540,12 @@ impl ModelWeights {
|
||||
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||
attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?,
|
||||
mlp_or_moe,
|
||||
ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
attention_norm: rms_norm(attention_norm, cfg.rms_norm_eps)?,
|
||||
mlp,
|
||||
ffn_norm: rms_norm(ffn_norm, cfg.rms_norm_eps)?,
|
||||
n_head: cfg.head_count,
|
||||
n_kv_head: cfg.head_count_kv,
|
||||
head_dim: cfg.embedding_length / cfg.head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
@ -430,7 +558,7 @@ impl ModelWeights {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
tok_embeddings: Embedding::new(tok_embeddings, cfg.embedding_length),
|
||||
layers,
|
||||
norm,
|
||||
output: QMatMul::from_qtensor(output)?,
|
||||
@ -473,7 +601,7 @@ impl ModelWeights {
|
||||
let _enter = layer.span_mlp.enter();
|
||||
let residual = &x;
|
||||
let x = layer.ffn_norm.forward(&x)?;
|
||||
let x = layer.mlp_or_moe.forward(&x)?;
|
||||
let x = layer.mlp.forward(&x)?;
|
||||
let x = (x + residual)?;
|
||||
layer_in = x
|
||||
}
|
||||
|
@ -1,288 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm};
|
||||
|
||||
/// Number of positions for which the rotary cos/sin tables are precomputed.
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QLinear {
|
||||
inner: candle::quantized::QMatMul,
|
||||
bias: Tensor,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl QLinear {
|
||||
fn new<R: std::io::Read + std::io::Seek>(
|
||||
ct: &gguf_file::Content,
|
||||
r: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
let w = ct.tensor(r, &format!("{name}.weight"), device)?;
|
||||
let b = ct.tensor(r, &format!("{name}.bias"), device)?;
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(w)?;
|
||||
let bias = b.dequantize(device)?;
|
||||
Ok(Self { inner, bias, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for QLinear {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)?.broadcast_add(&self.bias)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
ffn_up: QLinear,
|
||||
ffn_down: QLinear,
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
xs.apply(&self.ffn_up)?.gelu()?.apply(&self.ffn_down)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
attn_qkv: QLinear,
|
||||
attn_output: QLinear,
|
||||
attn_norm: LayerNorm,
|
||||
mlp: Mlp,
|
||||
n_head: usize,
|
||||
n_kv_head: usize,
|
||||
head_dim: usize,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
rope_dim: usize,
|
||||
neg_inf: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
}
|
||||
|
||||
/// Selects `on_true` (broadcast to the mask's shape) where `mask` is set and `on_false`
/// everywhere else.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
    let fill = on_true.broadcast_as(mask.shape().dims())?;
    mask.where_cond(&fill, on_false)
}
|
||||
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?;
|
||||
let xs_rot = xs.i((.., .., .., ..self.rope_dim))?;
|
||||
let xs_pass = xs.i((.., .., .., self.rope_dim..))?;
|
||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot.contiguous()?, &cos, &sin)?;
|
||||
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
|
||||
}
|
||||
|
||||
fn forward_attn(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
index_pos: usize,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let qkv =
|
||||
self.attn_qkv
|
||||
.forward(x)?
|
||||
.reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?;
|
||||
|
||||
let q = qkv.i((.., .., 0))?.transpose(1, 2)?;
|
||||
let k = qkv.i((.., .., 1))?.transpose(1, 2)?;
|
||||
let v = qkv.i((.., .., 2))?.transpose(1, 2)?;
|
||||
// This call to contiguous ensures that the fast kernel can be called below. It's
|
||||
// actually a no-op except when processing the initial prompt so has no significant
|
||||
// impact on performance.
|
||||
let v = v.contiguous()?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k.contiguous()?, v.contiguous()?),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let att = match mask {
|
||||
None => att,
|
||||
Some(mask) => {
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
masked_fill(&att, &mask, &self.neg_inf)?
|
||||
}
|
||||
};
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.attn_output.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
layers: Vec<LayerWeights>,
|
||||
output_norm: LayerNorm,
|
||||
output: QLinear,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<LayerNorm> {
|
||||
let w = w.dequantize(&w.device())?;
|
||||
let b = b.dequantize(&b.device())?;
|
||||
let ln = LayerNorm::new(w, b, eps);
|
||||
Ok(ln)
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
// Parameter extraction from metadata.
|
||||
let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
|
||||
let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let output_norm = layer_norm(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
ct.tensor(reader, "output_norm.bias", device)?,
|
||||
ln_eps,
|
||||
)?;
|
||||
let output = QLinear::new(&ct, reader, "output", device)?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?;
|
||||
let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?;
|
||||
let mlp = Mlp { ffn_up, ffn_down };
|
||||
let attn_norm = layer_norm(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?,
|
||||
ln_eps,
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
attn_norm,
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
rope_dim,
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
})
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
layers,
|
||||
output_norm,
|
||||
output,
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = xs.dims2()?;
|
||||
let mask = if seq_len == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.mask(seq_len, xs.device())?)
|
||||
};
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.tok_embeddings.forward(xs)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
let residual = &xs;
|
||||
let xs_norm = xs.apply(&layer.attn_norm)?;
|
||||
let attn_outputs = layer.forward_attn(&xs_norm, mask.as_ref(), index_pos)?;
|
||||
let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?;
|
||||
xs = (attn_outputs + feed_forward_hidden_states + residual)?
|
||||
}
|
||||
let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
|
||||
let _enter = self.span_output.enter();
|
||||
self.output.forward(&xs)
|
||||
}
|
||||
}
|
@ -1,301 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, RmsNorm};
|
||||
|
||||
/// Number of positions for which the rotary cos/sin tables are precomputed.
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QLinear {
|
||||
inner: candle::quantized::QMatMul,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl QLinear {
|
||||
fn new<R: std::io::Read + std::io::Seek>(
|
||||
ct: &gguf_file::Content,
|
||||
r: &mut R,
|
||||
name: &str,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
let w = ct.tensor(r, &format!("{name}.weight"), device)?;
|
||||
let inner = candle::quantized::QMatMul::from_qtensor(w)?;
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for QLinear {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Mlp {
|
||||
ffn_up: QLinear,
|
||||
ffn_down: QLinear,
|
||||
i_size: usize,
|
||||
}
|
||||
|
||||
impl Module for Mlp {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let up_states = xs.apply(&self.ffn_up)?;
|
||||
let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
|
||||
let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
|
||||
let up_states = (up_states * gate.silu()?)?;
|
||||
up_states.apply(&self.ffn_down)
|
||||
}
|
||||
}
|
||||
|
||||
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
|
||||
let w = w.dequantize(&w.device())?;
|
||||
let rms = RmsNorm::new(w, eps);
|
||||
Ok(rms)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct LayerWeights {
|
||||
attn_qkv: QLinear,
|
||||
attn_output: QLinear,
|
||||
attn_norm: RmsNorm,
|
||||
ffn_norm: RmsNorm,
|
||||
mlp: Mlp,
|
||||
n_head: usize,
|
||||
n_kv_head: usize,
|
||||
head_dim: usize,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
neg_inf: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
}
|
||||
|
||||
/// Selects `on_true` (broadcast to the mask's shape) where `mask` is set and `on_false`
/// everywhere else.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
    let fill = on_true.broadcast_as(mask.shape().dims())?;
    mask.where_cond(&fill, on_false)
}
|
||||
|
||||
impl LayerWeights {
|
||||
fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let _enter = self.span_rot.enter();
|
||||
let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
|
||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||
candle_nn::rotary_emb::rope(&xs.contiguous()?, &cos, &sin)
|
||||
}
|
||||
|
||||
fn forward_attn(
|
||||
&mut self,
|
||||
x: &Tensor,
|
||||
mask: Option<&Tensor>,
|
||||
index_pos: usize,
|
||||
) -> Result<Tensor> {
|
||||
let _enter = self.span_attn.enter();
|
||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||
let qkv = self.attn_qkv.forward(x)?;
|
||||
|
||||
let query_pos = self.n_head * self.head_dim;
|
||||
let q = qkv.narrow(D::Minus1, 0, query_pos)?;
|
||||
let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
|
||||
let v = qkv.narrow(
|
||||
D::Minus1,
|
||||
query_pos + self.n_kv_head * self.head_dim,
|
||||
self.n_kv_head * self.head_dim,
|
||||
)?;
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k.contiguous()?, v.contiguous()?),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let att = match mask {
|
||||
None => att,
|
||||
Some(mask) => {
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
masked_fill(&att, &mask, &self.neg_inf)?
|
||||
}
|
||||
};
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.attn_output.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
layers: Vec<LayerWeights>,
|
||||
output_norm: RmsNorm,
|
||||
output: QLinear,
|
||||
masks: HashMap<usize, Tensor>,
|
||||
span: tracing::Span,
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
};
|
||||
|
||||
// Parameter extraction from metadata.
|
||||
let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
|
||||
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
|
||||
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
|
||||
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
|
||||
let output = QLinear::new(&ct, reader, "output", device)?;
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?;
|
||||
let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?;
|
||||
let mlp = Mlp {
|
||||
ffn_up,
|
||||
ffn_down,
|
||||
i_size,
|
||||
};
|
||||
let attn_norm = rms_norm(
|
||||
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
|
||||
rms_eps,
|
||||
)?;
|
||||
let ffn_norm = rms_norm(
|
||||
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
|
||||
rms_eps,
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
attn_norm,
|
||||
ffn_norm,
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
})
|
||||
}
|
||||
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||
Ok(Self {
|
||||
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||
layers,
|
||||
output_norm,
|
||||
output,
|
||||
masks: HashMap::new(),
|
||||
span,
|
||||
span_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = xs.dims2()?;
|
||||
let mask = if seq_len == 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.mask(seq_len, xs.device())?)
|
||||
};
|
||||
let _enter = self.span.enter();
|
||||
let mut xs = self.tok_embeddings.forward(xs)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
let residual = &xs;
|
||||
let ys = xs.apply(&layer.attn_norm)?;
|
||||
let ys = layer.forward_attn(&ys, mask.as_ref(), index_pos)?;
|
||||
let ys = (ys + residual)?;
|
||||
let residual = &ys;
|
||||
let ys = ys.apply(&layer.ffn_norm)?;
|
||||
let ys = layer.mlp.forward(&ys)?;
|
||||
xs = (ys + residual)?
|
||||
}
|
||||
let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
|
||||
let _enter = self.span_output.enter();
|
||||
self.output.forward(&xs)
|
||||
}
|
||||
}
|
@ -27,6 +27,13 @@ struct RotaryEmbedding {
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
@ -41,6 +48,7 @@ impl RotaryEmbedding {
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
@ -56,8 +64,10 @@ impl RotaryEmbedding {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
|
@ -33,6 +33,13 @@ struct RotaryEmbedding {
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||
@ -47,6 +54,7 @@ impl RotaryEmbedding {
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
@ -62,8 +70,10 @@ impl RotaryEmbedding {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
@ -249,28 +259,30 @@ impl Module for SparseMoeBlock {
|
||||
|
||||
// In order to extract topk, we extract the data from the tensor and manipulate it
|
||||
// directly. Maybe we will want to use some custom ops instead at some point.
|
||||
let experts_per_tok = routing_weights
|
||||
.arg_sort_last_dim(false)?
|
||||
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
|
||||
.contiguous()?;
|
||||
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
|
||||
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||
|
||||
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
// top_x contains the row indexes to evaluate for each expert.
|
||||
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
|
||||
let mut top_x = vec![vec![]; self.experts.len()];
|
||||
let mut selected_experts = vec![vec![]; self.experts.len()];
|
||||
for (row_idx, (rw, expert_idxs)) in routing_weights
|
||||
.iter()
|
||||
.zip(experts_per_tok.iter())
|
||||
.enumerate()
|
||||
{
|
||||
let sum_rw = rw.iter().sum::<f32>();
|
||||
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
|
||||
top_x[expert_idx as usize].push(row_idx as u32);
|
||||
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
|
||||
selected_experts[expert_idx as usize].push(rw)
|
||||
for (row_idx, rw) in routing_weights.iter().enumerate() {
|
||||
let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
|
||||
dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
|
||||
let mut sum_routing_weights = 0f32;
|
||||
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
|
||||
let expert_idx = expert_idx as usize;
|
||||
let routing_weight = rw[expert_idx];
|
||||
sum_routing_weights += routing_weight;
|
||||
top_x[expert_idx].push(row_idx as u32);
|
||||
}
|
||||
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
|
||||
let expert_idx = expert_idx as usize;
|
||||
let routing_weight = if self.norm_topk_prob {
|
||||
rw[expert_idx] / sum_routing_weights
|
||||
} else {
|
||||
rw[expert_idx]
|
||||
};
|
||||
selected_experts[expert_idx].push(routing_weight)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -180,11 +180,6 @@ impl RmsNorm {
|
||||
let inner = candle_nn::rms_norm(size, eps, vb)?;
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
/// Calls the wrapped norm's `forward_diff` on `x` while the layer's tracing span is
/// entered.
pub fn forward_diff(&self, x: &Tensor) -> Result<Tensor> {
    let _guard = self.span.enter();
    let normed = self.inner.forward_diff(x)?;
    Ok(normed)
}
|
||||
}
|
||||
|
||||
impl Module for RmsNorm {
|
||||
|
@ -1,5 +1,26 @@
|
||||
/// When `true`, dropping a `Timer` ends the console timer registered under its label.
pub const WITH_TIMER: bool = true;
|
||||
|
||||
/// Guard that ends a labelled console timer when it is dropped.
struct Timer {
    /// Label passed to the console timer API on drop.
    label: &'static str,
}
|
||||
|
||||
// impl Timer {
|
||||
// fn new(label: &'static str) -> Self {
|
||||
// if WITH_TIMER {
|
||||
// web_sys::console::time_with_label(label);
|
||||
// }
|
||||
// Self { label }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl Drop for Timer {
|
||||
fn drop(&mut self) {
|
||||
if WITH_TIMER {
|
||||
web_sys::console::time_end_with_label(self.label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mod app;
|
||||
mod audio;
|
||||
pub mod languages;
|
||||
|
@ -55,7 +55,7 @@ fn quantized_matmul_neg() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates a vector similarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
/// Creates a vector similarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
|
||||
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
|
||||
const GGML_TEST_SIZE: usize = 32 * 128;
|
||||
(0..GGML_TEST_SIZE)
|
||||
|
Reference in New Issue
Block a user