Compare commits

...

15 Commits

Author SHA1 Message Date
7ff921c538 Add RandomNormal ONNX operator (#2200) 2024-05-21 21:47:32 +02:00
9b8537a62f Remove the deprecated wav crate in favor of hound. (#2202) 2024-05-21 21:43:35 +02:00
7ebc3548e1 Use flash-attn in gemma. (#2195)
* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
2024-05-18 19:18:59 +02:00
eefc1c77ef Support flash-attn in quantized phi3. (#2194) 2024-05-18 17:12:56 +02:00
01545f7303 Add a slice_set op. (#2193)
* Add a slice_set op.

* Add some testing.

* Add the dedicated kv-cache module.

* Derive debug and clone.

* Expose more kv-cache functions.

* Return the current data when appending.

* Use the new cache in the quantized phi3 model.
2024-05-18 15:58:18 +02:00
349c3e806a Support embedding model gte-Qwen1.5-7B-instruct (#2190)
* Support embedding model gte-Qwen1.5-7B-instruct

This is a text embedding model based on Qwen2. The two models share the
same architecture except for the last MLP module. This commit makes
minimal modifications to the existing Qwen2 implementation to support
both models.

An example is provided, and has been verified against the official
PyTorch implementation.

* Avoid doing the 'last-token filtering' based on the absence of attention mask.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2024-05-16 21:34:10 +02:00
bdaa34216a chore: add fix for windows cudarc into the readme (#2189) 2024-05-16 14:32:50 +02:00
cc80e065e5 Allow the threshold argument to be negative in the segment-anything example (#2187)
Threshold is 0.0 by default, negative values make more points included,
expanding the mask. Positive values make it more picky, making the mask
smaller.

Negative numbers start with a minus sign, which normally makes clap
consider it a flag.
2024-05-15 13:17:20 +02:00
13c64f6828 Fix VarBuilder::from_slice_safetensors (#2180)
Also implement SimpleBackend for SliceSafetensors

Signed-off-by: Harry Stern <harry@harrystern.net>
2024-05-12 07:26:06 +02:00
21f82a5155 Add SliceSafetensors. (#2179)
* Add SlicedSafetensors.

* And add some testing.
2024-05-11 13:15:42 +02:00
9cff7bc3f4 Make it possible to use TF32 accumulation in F32 matmuls. (#2178)
* Allow the use of tf32 accumulation in matmul.

* Better timings.

* Dummy versions for use when cuda is not enabled.
2024-05-11 12:28:39 +02:00
d9bc5ec151 Switch cudarc back to dynamic linking. (#2176) 2024-05-09 10:35:44 +02:00
84328e2b60 Update cudarc requirement from 0.11.0 to 0.11.1 (#2174)
* Upgrading cudarc dependency from v0.11.0 to v0.11.1 due to that version having resolved a compile-time bug.

See: https://github.com/huggingface/candle/issues/2173
2024-05-08 20:40:36 +02:00
82b641fd27 Update cudarc requirement from 0.10.0 to 0.11.0 (#2165)
* Update cudarc requirement from 0.10.0 to 0.11.0

Updates the requirements on [cudarc](https://github.com/coreylowman/cudarc) to permit the latest version.
- [Release notes](https://github.com/coreylowman/cudarc/releases)
- [Commits](https://github.com/coreylowman/cudarc/compare/v0.10.0...v0.10.0)

---
updated-dependencies:
- dependency-name: cudarc
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Use the default cuda version.

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-05-06 17:12:14 +02:00
01794dc16e Use write rather than try-write on the metal rw-locks. (#2162) 2024-05-05 07:22:46 +02:00
30 changed files with 959 additions and 128 deletions

View File

@ -43,11 +43,12 @@ candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
@ -69,7 +70,6 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -408,3 +408,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
error is generated.
#### CudaRC error
If you encounter an error like ``called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }`` on Windows, copy and rename the following 3 files (and make sure the renamed copies are in your path). The source paths depend on your CUDA version.
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`

View File

@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
parquet = { workspace = true }
image = { workspace = true }

View File

@ -5,32 +5,26 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Module, Tensor};
use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
let q_cpu = q.to_device(&Device::Cpu)?;
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
let q = QMatMul::from_qtensor(q)?;
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
let res_q_cuda = q.forward(&x)?;
println!("{res_q_cuda}");
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
let x_cpu = x.to_device(&Device::Cpu)?;
let res_q_cpu = q_cpu.forward(&x_cpu)?;
println!("{res_q_cpu}");
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
.abs()?
.flatten_all()?
.max(0)?;
println!("{diff}");
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
Ok(())
}

View File

@ -1615,12 +1615,8 @@ impl BackendStorage for CudaStorage {
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
unsafe {
self.device
.blas
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
}
.w()?;
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F32(out)
}
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
@ -1817,6 +1813,20 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// Process-wide switch for reduced-precision (e.g. tf32) accumulation in f32 GEMMs.
/// Starts out disabled.
static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Returns whether reduced precision reductions (e.g., with tf32 accumulation type) are
/// currently allowed with f32 GEMMs.
pub fn gemm_reduced_precision_f32() -> bool {
    use std::sync::atomic::Ordering::Relaxed;
    MM_F32_REDUCED_PRECISION.load(Relaxed)
}

/// Allows (`true`) or forbids (`false`) reduced precision reductions (e.g., with tf32
/// accumulation type) with f32 GEMMs.
pub fn set_gemm_reduced_precision_f32(b: bool) {
    use std::sync::atomic::Ordering::Relaxed;
    MM_F32_REDUCED_PRECISION.store(b, Relaxed)
}
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with f16 GEMMs.
@ -1842,6 +1852,51 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
}
/// Runs a strided-batched f32 GEMM through `cudarc::cublas::result::gemm_strided_batched_ex`.
///
/// The compute type is `CUBLAS_COMPUTE_32F_FAST_TF32` when `gemm_reduced_precision_f32()`
/// returns true and `CUBLAS_COMPUTE_32F` otherwise. All three operands are declared as
/// `CUDA_R_32F` and the algorithm is always `CUBLAS_GEMM_DEFAULT_TENSOR_OP`.
///
/// # Safety
///
/// NOTE(review): presumably `a`, `b` and `c` must be large enough for the dimensions, leading
/// dimensions, strides and batch size described by `cfg`; nothing in this function checks it.
unsafe fn gemm_strided_batched_f32(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f32>,
a: &cudarc::driver::CudaView<f32>,
b: &cudarc::driver::CudaView<f32>,
c: &mut CudaSlice<f32>,
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
use cudarc::cublas::sys;
use cudarc::driver::DevicePtrMut;
// Pick the accumulation mode from the global f32 reduced-precision flag.
let compute_type = if gemm_reduced_precision_f32() {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
} else {
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
};
// The scaling factors are passed as untyped pointers to the f32 fields stored in `cfg`.
let alpha = &cfg.gemm.alpha as *const f32 as *const _;
let beta = &cfg.gemm.beta as *const f32 as *const _;
cudarc::cublas::result::gemm_strided_batched_ex(
*cublas.handle(),
cfg.gemm.transa,
cfg.gemm.transb,
cfg.gemm.m,
cfg.gemm.n,
cfg.gemm.k,
alpha,
// Operand `a`: device pointer, element type, leading dimension, batch stride.
*a.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.lda,
cfg.stride_a,
// Operand `b`: same layout of arguments as for `a`.
*b.device_ptr() as *const _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldb,
cfg.stride_b,
beta,
// Output `c`: written through a mutable device pointer.
*c.device_ptr_mut() as *mut _,
sys::cudaDataType_t::CUDA_R_32F,
cfg.gemm.ldc,
cfg.stride_c,
cfg.batch_size,
compute_type,
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
)
}
unsafe fn gemm_strided_batched_f16(
cublas: &cudarc::cublas::CudaBlas,
cfg: StridedBatchedConfig<f16>,

View File

@ -258,3 +258,13 @@ pub fn gemm_reduced_precision_bf16() -> bool {
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
/// allowed with bf16 GEMMs.
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
/// Dummy counterpart of the getter for the flag that controls whether reduced precision
/// reductions (e.g., with tf32 accumulation type) are allowed with f32 GEMMs.
///
/// This version holds no state and always returns `true`.
pub fn gemm_reduced_precision_f32() -> bool {
true
}
/// Dummy counterpart of the setter for the flag that controls whether reduced precision
/// reductions (e.g., with tf32 accumulation type) are allowed with f32 GEMMs.
///
/// This version is a no-op: the argument is ignored.
pub fn set_gemm_reduced_precision_f32(_b: bool) {}

View File

@ -100,11 +100,11 @@ impl MetalDevice {
}
pub fn command_buffer(&self) -> Result<CommandBuffer> {
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
let mut command_buffer = command_buffer_lock.to_owned();
let mut index = self
.command_buffer_index
.try_write()
.write()
.map_err(MetalError::from)?;
if *index > self.compute_per_buffer {
command_buffer.commit();
@ -119,7 +119,7 @@ impl MetalDevice {
}
pub fn wait_until_completed(&self) -> Result<()> {
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
match command_buffer.status() {
metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled
@ -179,7 +179,7 @@ impl MetalDevice {
size,
MTLResourceOptions::StorageModeManaged,
);
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
let subbuffers = buffers
.entry((size, MTLResourceOptions::StorageModeManaged))
.or_insert(vec![]);
@ -232,7 +232,7 @@ impl MetalDevice {
}
fn drop_unused_buffers(&self) -> Result<()> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
for subbuffers in buffers.values_mut() {
let newbuffers = subbuffers
.iter()
@ -251,7 +251,7 @@ impl MetalDevice {
option: MTLResourceOptions,
_name: &str,
) -> Result<Arc<Buffer>> {
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
// Cloning also ensures we increment the strong count
return Ok(b.clone());

View File

@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
use metal::{Buffer, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::{Arc, Mutex, RwLock, TryLockError};
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
}
}
impl<T> From<PoisonError<T>> for MetalError {
    /// Converts a poisoned-lock error into `MetalError::LockError`, keeping only the textual
    /// description of the poison error.
    fn from(poisoned: PoisonError<T>) -> Self {
        let description = poisoned.to_string();
        Self::LockError(LockError::Poisoned(description))
    }
}
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {

View File

@ -349,6 +349,30 @@ impl MmapedSafetensors {
}
}
/// Safetensors data parsed from a borrowed byte slice.
pub struct SliceSafetensors<'a> {
    safetensors: SafeTensors<'a>,
}

impl<'a> SliceSafetensors<'a> {
    /// Creates a wrapper around a binary buffer and deserializes the safetensors header.
    pub fn new(buffer: &'a [u8]) -> Result<Self> {
        let parsed = safetensors::SafeTensors::deserialize(buffer)?;
        Ok(Self {
            safetensors: parsed,
        })
    }

    /// Looks up the tensor called `name` and loads it on `dev`.
    pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
        let view = self.safetensors.tensor(name)?;
        view.load(dev)
    }

    /// Returns every `(name, view)` pair stored in the buffer.
    pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
        self.safetensors.tensors()
    }

    /// Returns the raw view for the tensor called `name`, or an error if the lookup fails.
    pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
        let view = self.safetensors.tensor(name)?;
        Ok(view)
    }
}
pub struct BufferedSafetensors {
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
}

View File

@ -235,4 +235,66 @@ impl Tensor {
}
Ok(crate::tensor::from_storage(storage, shape, op, false))
}
/// Set the values on `self` using values from `src`. The copy starts at the specified
/// `offset` for the target dimension `dim` on `self`.
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
/// has to be greater than or equal to `offset` plus the `src` size.
///
/// Note that this modifies `self` in place and as such is not compatibel with
/// back-propagation.
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
let dim = dim.to_index(self.shape(), "slice-set")?;
if !self.is_contiguous() || !src.is_contiguous() {
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
}
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-set",
}
.bt())?
}
if self.device().location() != src.device().location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-set",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: self.shape().clone(),
}
.bt())?
}
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
if dim_idx == dim && *v2 + offset > *v1 {
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
}
if dim_idx != dim && v1 != v2 {
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
}
}
let block_size: usize = src.dims().iter().skip(1 + dim).product();
let d1: usize = src.dims().iter().take(dim).product();
let d2 = block_size * src.dims()[dim];
let dst_o = self.layout().start_offset() + offset * block_size;
let src_o = src.layout().start_offset();
src.storage().copy2d(
&mut self.storage_mut(),
d1,
d2,
/* src_s */ d2,
/* dst_s */ block_size * self.dims()[dim],
src_o,
dst_o,
)?;
Ok(())
}
}

View File

@ -1,5 +1,31 @@
use candle_core::{DType, Result, Tensor};
/// Path of a temporary file that is deleted when the value is dropped.
///
/// Creating a `TmpFile` only computes the path; the file itself is expected to be written by
/// the test that uses it.
struct TmpFile(std::path::PathBuf);

impl TmpFile {
    /// Builds a path in the system temporary directory whose name combines `base`, the
    /// current process id and the current thread id.
    fn create(base: &str) -> TmpFile {
        let filename = std::env::temp_dir().join(format!(
            "candle-{}-{}-{:?}",
            base,
            std::process::id(),
            std::thread::current().id(),
        ));
        TmpFile(filename)
    }
}

impl std::convert::AsRef<std::path::Path> for TmpFile {
    fn as_ref(&self) -> &std::path::Path {
        self.0.as_path()
    }
}

impl Drop for TmpFile {
    /// Removes the file on a best-effort basis.
    ///
    /// This must never panic: the value is also dropped when a test bails out before the file
    /// was written, and a panic raised while already unwinding aborts the whole process.
    fn drop(&mut self) {
        if let Err(err) = std::fs::remove_file(&self.0) {
            // A file that was never created is not worth reporting.
            if err.kind() != std::io::ErrorKind::NotFound {
                eprintln!(
                    "failed to remove temporary file {}: {err}",
                    self.0.display()
                );
            }
        }
    }
}
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
);
Ok(())
}
/// Saves a tensor to a temporary safetensors file and checks that it reads back unchanged,
/// both through the file loader and through `SliceSafetensors` on the raw bytes.
#[test]
fn safetensors() -> Result<()> {
    use candle_core::safetensors::Load;

    let cpu = candle_core::Device::Cpu;
    let tmp_file = TmpFile::create("st");
    let t = Tensor::arange(0f32, 24f32, &cpu)?;
    t.save_safetensors("t", &tmp_file)?;
    // Sum of absolute differences against the reference tensor.
    let l1_distance =
        |other: &Tensor| -> Result<f32> { (&t - other)?.abs()?.sum_all()?.to_vec0::<f32>() };

    // Load from file.
    let from_file = candle_core::safetensors::load(&tmp_file, &cpu)?;
    assert_eq!(l1_distance(from_file.get("t").unwrap())?, 0f32);

    // Load from bytes.
    let bytes = std::fs::read(tmp_file)?;
    let from_slice = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
    let reloaded = from_slice.get("t").unwrap().load(&cpu)?;
    assert_eq!(l1_distance(&reloaded)?, 0f32);
    Ok(())
}

View File

@ -665,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
Ok(())
}
fn slice_set(device: &Device) -> Result<()> {
let (b, h, max_t, d) = (2, 4, 7, 3);
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
cache.slice_set(&tensor, 2, 0)?;
let cache_t = cache.narrow(2, 0, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
cache.slice_set(&tensor, 2, 1)?;
let cache_t = cache.narrow(2, 1, 4)?;
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
cache.slice_set(&ones, 2, 6)?;
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
let diff = (cache.narrow(2, 6, 1)? - 1.)?
.abs()?
.sum_all()?
.to_vec0::<f32>()?;
assert_eq!(diff, 0.);
Ok(())
}
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
@ -1146,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
test_device!(min, min_cpu, min_gpu, min_metal);

View File

@ -193,6 +193,9 @@ struct Args {
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
fn main() -> Result<()> {
@ -270,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = Model::new(args.use_flash_attn, &config, vb)?;
println!("loaded the model in {:?}", start.elapsed());

View File

@ -0,0 +1,19 @@
# gte-Qwen1.5-7B-instruct
gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
## Running the example
Automatically download the model from the HuggingFace hub:
```bash
$ cargo run --example gte-qwen --release
```
or, load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```

View File

@ -0,0 +1,178 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
Tokenizer,
};
// gte-Qwen1.5-7B-instruct use EOS token as padding token
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;
// Command-line arguments of the gte-Qwen embedding example.
// Plain `//` comments are used for the notes added below so that the generated `--help`
// text is left unchanged.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
// Model repository on the HuggingFace Hub; only used when `local_repo` is not set.
#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
model_id: String,
// Revision of the Hub repository to fetch.
#[arg(long, default_value = "main")]
revision: String,
// Local directory holding the model files; when set, the Hub is not contacted.
#[arg(long)]
local_repo: Option<String>,
}
/// Locations of the files needed to build the tokenizer and the model.
#[derive(Debug)]
struct ConfigFiles {
/// Path to `config.json`.
pub config: std::path::PathBuf,
/// Path to `tokenizer.json`.
pub tokenizer: std::path::PathBuf,
/// Paths to the safetensors weight files.
pub weights: Vec<std::path::PathBuf>,
}
// Resolves the config, tokenizer and weight files of `model_id` at `revision` through the
// HuggingFace Hub API. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
    let api = Api::new()?;
    let repo_spec = Repo::with_revision(
        model_id.to_string(),
        RepoType::Model,
        revision.to_string(),
    );
    let repo = api.repo(repo_spec);

    let config = repo.get("config.json")?;
    let tokenizer = repo.get("tokenizer.json")?;
    let weights = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    Ok(ConfigFiles {
        config,
        tokenizer,
        weights,
    })
}
// Loading the model from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
let local_path = std::path::PathBuf::from(local_path);
let weight_path = local_path.join("model.safetensors.index.json");
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
let weight_map = match json.get("weight_map") {
Some(serde_json::Value::Object(map)) => map,
Some(_) => panic!("`weight map` is not a map"),
None => panic!("`weight map` not found"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
safetensors_files.insert(
value
.as_str()
.expect("Weight files should be parsed as strings"),
);
}
let safetensors_paths = safetensors_files
.iter()
.map(|v| local_path.join(v))
.collect::<Vec<_>>();
Ok(ConfigFiles {
config: local_path.join("config.json"),
tokenizer: local_path.join("tokenizer.json"),
weights: safetensors_paths,
})
}
// Entry point: resolves the model files, builds the left-padding tokenizer and the model,
// embeds two queries and two passages, and prints the query/passage similarity scores.
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
// Keep the guard bound for the whole of `main` when tracing is enabled.
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
// Fetch the model. Do this offline if local path provided.
println!("Fetching model files...");
let start = std::time::Instant::now();
let config_files = match args.local_repo {
Some(local_path) => load_from_local(&local_path)?,
None => load_from_hub(&args.model_id, &args.revision)?,
};
println!("Model file retrieved in {:?}", start.elapsed());
// Inputs will be padded to the longest sequence in the batch.
// Padding goes on the left and uses the EOS token as the pad token.
let padding = PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_to_multiple_of: None,
pad_id: EOS_TOKEN_ID,
pad_type_id: 0,
pad_token: String::from(EOS_TOKEN),
};
// Tokenizer setup
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
tokenizer.with_padding(Some(padding));
// Model initialization
// BF16 is used on CUDA devices, F32 everywhere else.
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
let mut model = Model::new(&config, vb)?;
// Note: `start` was taken before fetching, so this duration includes the fetch time.
println!("Model loaded in {:?}", start.elapsed());
// Encode the queries and the targets
// The first two entries are instructed queries, the last two are passages; every entry
// ends with the EOS token.
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
let documents = vec![
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
format!("{instruct}summit define{EOS_TOKEN}"),
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
];
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
let tokens = Tensor::new(tokens, &device)?;
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
let mask = Tensor::new(mask, &device)?;
// Inference
// NOTE(review): the second argument (0) is presumably the sequence offset - confirm against
// the qwen2 `Model::forward` signature.
let start_gen = std::time::Instant::now();
let logits = model.forward(&tokens, 0, Some(&mask))?;
// Extract the last hidden states as embeddings since inputs are padded left.
let (_, seq_len, _) = logits.dims3()?;
let embd = logits
.narrow(1, seq_len - 1, 1)?
.squeeze(1)?
.to_dtype(DType::F32)?;
// Calculate the relevance scores: each embedding is divided by its L2 norm, then the two
// query rows are multiplied with the transposed two passage rows.
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;
// Print the results
println!("Embedding done in {:?}", start_gen.elapsed());
println!("Scores: {:?}", scores.to_vec2::<f32>()?);
Ok(())
}

View File

@ -90,6 +90,9 @@ struct Args {
/// The model size to use.
#[arg(long, default_value = "phi-3b")]
which: Which,
#[arg(long)]
use_flash_attn: bool,
}
impl Args {
@ -213,7 +216,13 @@ fn main() -> anyhow::Result<()> {
);
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
1,
args.use_flash_attn,
model,
&mut file,
&device,
)?),
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
}
};

View File

@ -7,7 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
use candle::{DType, Device, Tensor};

View File

@ -39,7 +39,7 @@ struct Args {
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
/// mask, positive makes the mask more selective.
#[arg(long, default_value_t = 0.)]
#[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
threshold: f32,
/// Enable tracing (generates a trace-timestamp.json file).

View File

@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);

View File

@ -139,7 +139,9 @@ impl FlashAttn {
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };

101
candle-nn/src/kv_cache.rs Normal file
View File

@ -0,0 +1,101 @@
use candle::{DType, Device, Result, Shape, Tensor};
/// Pre-allocated buffer that is filled incrementally along one dimension.
///
/// The full buffer is allocated (zero-filled) up front. `current_seq_len` tracks how much of
/// dimension `dim` has been written so far.
#[derive(Debug, Clone)]
pub struct Cache {
    all_data: Tensor,
    dim: usize,
    current_seq_len: usize,
    max_seq_len: usize,
}

impl Cache {
    /// Allocates a zero-filled tensor of the given `shape`. The size of `shape` along `dim`
    /// becomes the maximum sequence length.
    pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
        dim: D,
        shape: S,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let shape: Shape = shape.into();
        let dim = dim.to_index(&shape, "kv-cache")?;
        let max_seq_len = shape.dims()[dim];
        Ok(Self {
            all_data: Tensor::zeros(shape, dtype, dev)?,
            dim,
            current_seq_len: 0,
            max_seq_len,
        })
    }

    /// Index of the dimension along which data is appended.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Number of positions written so far along `dim`.
    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    /// Capacity of the buffer along `dim`.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// The whole underlying buffer, including the part that has not been written yet.
    pub fn all_data(&self) -> &Tensor {
        &self.all_data
    }

    /// The part of the buffer written so far, i.e. the first `current_seq_len` positions
    /// along `dim`.
    pub fn current_data(&self) -> Result<Tensor> {
        let written = self.current_seq_len;
        self.all_data.narrow(self.dim, 0, written)
    }

    /// Writes `src` right after the data already stored and advances the current length by
    /// the size of `src` along `dim`. Fails if this would exceed the maximum length.
    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        let new_len = self.current_seq_len + seq_len;
        if new_len > self.max_seq_len {
            candle::bail!(
                "kv-cache: above max-seq-len {}+{seq_len}>{}",
                self.current_seq_len,
                self.max_seq_len
            )
        }
        let write_at = self.current_seq_len;
        self.all_data.slice_set(src, self.dim, write_at)?;
        self.current_seq_len = new_len;
        Ok(())
    }
}
/// A pair of [`Cache`]s, one for the keys and one for the values, built with the same
/// dimension, shape, dtype and device.
#[derive(Debug, Clone)]
pub struct KvCache {
    k: Cache,
    v: Cache,
}

impl KvCache {
    /// Creates the key and value caches; see [`Cache::new`] for the meaning of the arguments.
    pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
        dim: D,
        shape: S,
        dtype: DType,
        dev: &Device,
    ) -> Result<Self> {
        let shape: Shape = shape.into();
        let dim = dim.to_index(&shape, "kv-cache")?;
        let make_cache = || Cache::new(dim, &shape, dtype, dev);
        Ok(Self {
            k: make_cache()?,
            v: make_cache()?,
        })
    }

    /// Keys stored so far.
    pub fn k(&self) -> Result<Tensor> {
        self.k.current_data()
    }

    /// Values stored so far.
    pub fn v(&self) -> Result<Tensor> {
        self.v.current_data()
    }

    /// Appends `k` and `v` to their respective caches, keys first, and returns all the keys
    /// and values stored so far.
    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.k.append(k)?;
        self.v.append(v)?;
        Ok((self.k.current_data()?, self.v.current_data()?))
    }
}

View File

@ -6,6 +6,7 @@ pub mod encoding;
pub mod func;
pub mod group_norm;
pub mod init;
pub mod kv_cache;
pub mod layer_norm;
pub mod linear;
pub mod loss;

View File

@ -422,6 +422,32 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
}
}
impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
    /// Loads the tensor `name` on `dev`, converts it to `dtype` and checks that its shape is
    /// exactly `s`. The initialization hint is ignored.
    fn get(
        &self,
        s: Shape,
        name: &str,
        _: crate::Init,
        dtype: DType,
        dev: &Device,
    ) -> Result<Tensor> {
        let loaded = self.load(name, dev)?;
        let tensor = loaded.to_dtype(dtype)?;
        if tensor.shape() == &s {
            return Ok(tensor);
        }
        Err(candle::Error::UnexpectedShape {
            msg: format!("shape mismatch for {name}"),
            expected: s,
            got: tensor.shape().clone(),
        }
        .bt())
    }

    /// Reports whether looking up `name` in the underlying safetensors data succeeds.
    fn contains_tensor(&self, name: &str) -> bool {
        self.get(name).is_ok()
    }
}
impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` using a custom backend.
///
@ -481,12 +507,18 @@ impl<'a> VarBuilder<'a> {
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
/// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
}
/// Initializes a `VarBuilder` from a borrowed byte slice in the safetensor format.
///
/// The slice is borrowed for `'a`, the lifetime of the returned builder.
/// Errors from constructing the safetensors view are propagated.
pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
    let backend = candle::safetensors::SliceSafetensors::new(data)?;
    let backend = Box::new(backend);
    // The builder stores its own clone of the device handle.
    let device = dev.clone();
    Ok(Self::from_backend(backend, dtype, device))
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;

View File

@ -971,7 +971,7 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
"RandomUniform" => {
random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
// default
@ -979,36 +979,42 @@ pub fn simple_eval(
Ok(dt) => match dtype(dt) {
Some(DType::U8 | DType::U32 | DType::I64) => {
bail!(
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
"unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
node.name
)
}
Some(dt) => dt,
None => {
bail!(
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
"unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
},
Err(_) => {
bail!(
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
"unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
};
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
if seed.is_some() {
bail!("seed for RandomUniform is currently not supported")
bail!("seed for {random_type} is currently not supported")
};
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
.iter()
.map(|x| *x as usize)
.collect();
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
let output = if random_type == "RandomUniform" {
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
} else {
let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
};
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),

View File

@ -2020,6 +2020,150 @@ fn test_random_uniform() -> Result<()> {
Ok(())
}
// "RandomNormal"
// Evaluates a single-node RandomNormal graph for several shapes and optional
// `mean` / `scale` attributes, and checks that no two generated values are equal.
#[test]
fn test_random_normal() -> Result<()> {
    test(vec![3, 2, 1, 4], None, None)?;
    test(vec![2, 2, 2, 2], Some(-10.0), None)?;
    test(vec![2, 2, 2, 2], None, Some(10.0))?;
    test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;

    // Builds an attribute whose name, reference name and doc string are all
    // `name`; every payload field other than `i`, `f` and `ints` stays empty.
    fn attribute(name: &str, r#type: i32, i: i64, f: f32, ints: Vec<i64>) -> AttributeProto {
        AttributeProto {
            name: name.to_string(),
            ref_attr_name: name.to_string(),
            i,
            doc_string: name.to_string(),
            r#type,
            f,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints,
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        }
    }

    fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
        // `shape` and `dtype` are always present; `mean` and `scale` are only
        // attached when the caller supplied a value.
        let mut attrs = vec![
            attribute("shape", 7, 0, 0.0, shape),   // type 7: INTS
            attribute("dtype", 2, 11, 0.0, vec![]), // type 2: INT, value 11: DOUBLE
        ];
        attrs.extend(mean.map(|m| attribute("mean", 1, 0, m, vec![]))); // type 1: FLOAT
        attrs.extend(scale.map(|s| attribute("scale", 1, 0, s, vec![]))); // type 1: FLOAT

        let node = NodeProto {
            op_type: "RandomNormal".to_string(),
            domain: "".to_string(),
            attribute: attrs,
            input: vec![],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        };
        let graph_output = ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        };
        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
            node: vec![node],
            name: "".to_string(),
            initializer: vec![],
            input: vec![],
            output: vec![graph_output],
            value_info: vec![],
            doc_string: "".to_string(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        }));

        let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
        let data = z.flatten_all()?.to_vec1::<f64>()?;

        // test if values are unique: compare every element with each later one
        for (i, a) in data.iter().enumerate() {
            for b in &data[i + 1..] {
                assert_ne!(a, b);
            }
        }
        Ok(())
    }

    Ok(())
}
// "Range"
#[test]
fn test_range() -> Result<()> {

View File

@ -73,13 +73,6 @@ struct RotaryEmbedding {
cos: Tensor,
}
/// Splits `xs` along its last dimension into a first part of `last_dim / 2`
/// elements and a second part holding the remainder, then concatenates the
/// negated second part followed by the first part along that same dimension.
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let half = last_dim / 2;
    let front = xs.narrow(D::Minus1, 0, half)?;
    // For an odd `last_dim` the second part is one element longer than the first.
    let back = xs.narrow(D::Minus1, half, last_dim - half)?;
    let negated_back = back.neg()?;
    Tensor::cat(&[&negated_back, &front], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
@ -94,7 +87,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
@ -110,10 +102,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
@ -163,10 +153,16 @@ struct Attention {
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_flash_attn: bool,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
@ -188,6 +184,7 @@ impl Attention {
head_dim,
rotary_emb,
kv_cache: None,
use_flash_attn,
})
}
@ -231,7 +228,14 @@ impl Attention {
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
let attn_output = {
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = query_states.transpose(1, 2)?;
let k = key_states.transpose(1, 2)?;
let v = value_states.transpose(1, 2)?;
let scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
} else {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
@ -253,6 +257,22 @@ impl Attention {
}
}
/// Thin wrapper that forwards all arguments unchanged to
/// `candle_flash_attn::flash_attn`.
///
/// Only compiled when the `flash-attn` feature is enabled; a same-named
/// fallback exists for builds without the feature.
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
@ -262,8 +282,13 @@ struct DecoderLayer {
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@ -312,7 +337,7 @@ pub struct Model {
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
@ -320,7 +345,8 @@ impl Model {
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
let layer =
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;

View File

@ -3,9 +3,7 @@ use std::collections::HashMap;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, RmsNorm};
pub const MAX_SEQ_LEN: usize = 4096;
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
#[derive(Debug, Clone)]
struct QLinear {
@ -70,7 +68,8 @@ struct LayerWeights {
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
kv_cache: KvCache,
use_flash_attn: bool,
span_attn: tracing::Span,
span_rot: tracing::Span,
}
@ -122,40 +121,55 @@ impl LayerWeights {
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k.contiguous()?, v.contiguous()?),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k.contiguous()?, v.contiguous()?)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?;
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k.contiguous()?, v.contiguous()?)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
}
let y = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = q.to_dtype(DType::BF16)?.transpose(1, 2)?;
let k = k.to_dtype(DType::BF16)?.transpose(1, 2)?;
let v = v.to_dtype(DType::BF16)?.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
.to_dtype(DType::F32)?
.transpose(1, 2)?
} else {
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
att.matmul(&v.contiguous()?)?
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attn_output.forward(&y)?;
Ok(y)
}
}
/// Thin wrapper that forwards all arguments unchanged to
/// `candle_flash_attn::flash_attn`.
///
/// Only compiled when the `flash-attn` feature is enabled; a same-named
/// fallback exists for builds without the feature.
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
@ -169,6 +183,7 @@ pub struct ModelWeights {
fn precomput_freqs_cis(
head_dim: usize,
max_seq_len: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
@ -177,9 +192,9 @@ fn precomput_freqs_cis(
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.reshape((max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
@ -188,6 +203,8 @@ fn precomput_freqs_cis(
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
batch_size: usize,
use_flash_attn: bool,
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
@ -202,16 +219,19 @@ impl ModelWeights {
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
let head_dim = embedding_length / head_count;
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
let output = QLinear::new(&ct, reader, "output", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
@ -232,6 +252,12 @@ impl ModelWeights {
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let kv_cache = KvCache::new(
2,
(batch_size, head_count_kv, max_seq_len, head_dim),
DType::F32,
device,
)?;
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
@ -240,11 +266,12 @@ impl ModelWeights {
mlp,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
head_dim,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
kv_cache,
use_flash_attn,
span_attn,
span_rot,
})

View File

@ -1,5 +1,5 @@
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
@ -250,7 +250,6 @@ pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
@ -269,19 +268,17 @@ impl Model {
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
fn prepare_causal_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
@ -301,7 +298,7 @@ impl Model {
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
@ -310,21 +307,42 @@ impl Model {
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
/// Turns a 2-D `(b_sz, sql_len)` attention mask into an additive mask of
/// shape `(b_sz, 1, sql_len, sql_len)` in the model dtype.
///
/// Each batch row is expanded to `(1, 1, sql_len, sql_len)` and the rows are
/// concatenated along dimension 0. Positions selected by the mask become
/// `0`; all other positions become negative infinity.
// NOTE(review): `where_cond` presumably needs `attn_mask` to have an integer
// dtype — confirm against callers.
fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
    let (b_sz, sql_len) = attn_mask.dims2()?;
    let per_batch = (0..b_sz)
        .map(|b| -> Result<Tensor> {
            let row = attn_mask.i((b, ..))?;
            row.expand((1, 1, sql_len, sql_len))
        })
        .collect::<Result<Vec<Tensor>>>()?;
    let mask = Tensor::cat(&per_batch, 0)?;
    // Value used where the mask selects the position.
    let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
    // Value used everywhere else.
    let neg_inf = Tensor::new(f32::NEG_INFINITY, &self.device)?;
    let on_false = neg_inf.broadcast_as(mask.shape())?.to_dtype(self.dtype)?;
    mask.where_cond(&on_true, &on_false)
}
pub fn forward(
&mut self,
input_ids: &Tensor,
seqlen_offset: usize,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
let attention_mask: Option<Tensor> = match attn_mask {
Some(mask) => Some(self.prepare_attention_mask(mask)?),
None => {
if seq_len <= 1 {
None
} else {
Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
}
}
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
xs.apply(&self.norm)
}
pub fn clear_kv_cache(&mut self) {
@ -333,3 +351,32 @@ impl Model {
}
}
}
/// Wraps the base `Model` with a bias-free `lm_head` projection applied to the
/// last sequence position of the base model's output.
#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
    // Model whose `forward` output is fed to `lm_head`.
    base_model: Model,
    // Linear layer (no bias) from `hidden_size` to `vocab_size`.
    lm_head: Linear,
}

impl ModelForCausalLM {
    /// Loads the `lm_head` weights from the "lm_head" prefix of `vb`, then
    /// builds the base model from `vb` itself.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        // The head is created first because `vb` is moved into `Model::new`.
        let head_vb = vb.pp("lm_head");
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, head_vb)?;
        Ok(Self {
            base_model: Model::new(cfg, vb)?,
            lm_head,
        })
    }

    /// Runs the base model without an explicit attention mask, keeps only the
    /// last position along dimension 1, and applies `lm_head` to it.
    ///
    /// `input_ids` must be 2-D; its second dimension is the sequence length.
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (_, seq_len) = input_ids.dims2()?;
        let hidden = self.base_model.forward(input_ids, seqlen_offset, None)?;
        let last_position = hidden.narrow(1, seq_len - 1, 1)?;
        last_position.apply(&self.lm_head)
    }

    /// Clears the key/value cache of the wrapped base model.
    pub fn clear_kv_cache(&mut self) {
        self.base_model.clear_kv_cache()
    }
}

View File

@ -21,7 +21,7 @@ log = { workspace = true }
rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
wav = { workspace = true }
hound = { workspace = true }
safetensors = { workspace = true }
# Wasm specific crates.

View File

@ -345,16 +345,19 @@ impl Decoder {
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
let device = Device::Cpu;
let mut wav_input = std::io::Cursor::new(wav_input);
let (header, data) = wav::read(&mut wav_input)?;
console_log!("loaded wav data: {header:?}");
if header.sampling_rate != m::SAMPLE_RATE as u32 {
let wav_reader = hound::WavReader::new(&mut wav_input)?;
let spec = wav_reader.spec();
console_log!("loaded wav data: {spec:?}");
if spec.sample_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mut data = wav_reader.into_samples::<i16>().collect::<Vec<_>>();
data.truncate(data.len() / spec.channels as usize);
let mut pcm_data = Vec::with_capacity(data.len());
for d in data.into_iter() {
let d = d?;
pcm_data.push(d as f32 / 32768.)
}
console_log!("pcm data loaded {}", pcm_data.len());
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
let mel_len = mel.len();