mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
1 Commits
0.5.1
...
improve-sa
Author | SHA1 | Date | |
---|---|---|---|
f7980abbcd |
@ -43,12 +43,11 @@ candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
|
||||
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
criterion = { version = "0.5.1", default-features=false }
|
||||
cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
||||
fancy-regex = "0.13.0"
|
||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||
hf-hub = "0.3.0"
|
||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||
hound = "3.5.1"
|
||||
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
|
||||
imageproc = { version = "0.24.0", default-features = false }
|
||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||
@ -70,6 +69,7 @@ tokenizers = { version = "0.19.1", default-features = false }
|
||||
tracing = "0.1.37"
|
||||
tracing-chrome = "0.7.1"
|
||||
tracing-subscriber = "0.3.7"
|
||||
wav = "1.0.0"
|
||||
yoke = { version = "0.7.2", features = ["derive"] }
|
||||
zip = { version = "1.1.1", default-features = false }
|
||||
metal = { version = "0.27.0", features = ["mps"]}
|
||||
|
@ -408,10 +408,3 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
|
||||
|
||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||
error is generated.
|
||||
|
||||
#### CudaRC error
|
||||
|
||||
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
|
||||
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
|
||||
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
|
||||
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
|
||||
|
@ -37,6 +37,7 @@ tokenizers = { workspace = true, features = ["onig"] }
|
||||
tracing = { workspace = true }
|
||||
tracing-chrome = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||
parquet = { workspace = true }
|
||||
image = { workspace = true }
|
||||
|
@ -5,26 +5,32 @@ extern crate accelerate_src;
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
use candle_core::{Device, Module, Tensor};
|
||||
|
||||
use candle_core::quantized::{QMatMul, QTensor};
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("fp32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||
let _x1 = x.matmul(&x)?;
|
||||
drop(_x1);
|
||||
let start_time = std::time::Instant::now();
|
||||
let _x1 = x.matmul(&x)?;
|
||||
device.synchronize()?;
|
||||
println!("tf32: {:?}", start_time.elapsed());
|
||||
drop(_x1);
|
||||
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
|
||||
let q_cpu = q.to_device(&Device::Cpu)?;
|
||||
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
|
||||
let q = QMatMul::from_qtensor(q)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
|
||||
let res_q_cuda = q.forward(&x)?;
|
||||
println!("{res_q_cuda}");
|
||||
|
||||
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
|
||||
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
|
||||
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
|
||||
let x_cpu = x.to_device(&Device::Cpu)?;
|
||||
let res_q_cpu = q_cpu.forward(&x_cpu)?;
|
||||
println!("{res_q_cpu}");
|
||||
|
||||
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
|
||||
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
|
||||
.abs()?
|
||||
.flatten_all()?
|
||||
.max(0)?;
|
||||
println!("{diff}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1615,8 +1615,12 @@ impl BackendStorage for CudaStorage {
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||
.w()?;
|
||||
unsafe {
|
||||
self.device
|
||||
.blas
|
||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
||||
}
|
||||
.w()?;
|
||||
CudaStorageSlice::F32(out)
|
||||
}
|
||||
(CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => {
|
||||
@ -1813,20 +1817,6 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||
std::sync::atomic::AtomicBool::new(false);
|
||||
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||
std::sync::atomic::AtomicBool::new(false);
|
||||
static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||
std::sync::atomic::AtomicBool::new(false);
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||
/// allowed with f32 GEMMs.
|
||||
pub fn gemm_reduced_precision_f32() -> bool {
|
||||
MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||
/// allowed with f32 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_f32(b: bool) {
|
||||
MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with f16 GEMMs.
|
||||
@ -1852,51 +1842,6 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
|
||||
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
||||
}
|
||||
|
||||
unsafe fn gemm_strided_batched_f32(
|
||||
cublas: &cudarc::cublas::CudaBlas,
|
||||
cfg: StridedBatchedConfig<f32>,
|
||||
a: &cudarc::driver::CudaView<f32>,
|
||||
b: &cudarc::driver::CudaView<f32>,
|
||||
c: &mut CudaSlice<f32>,
|
||||
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
|
||||
use cudarc::cublas::sys;
|
||||
use cudarc::driver::DevicePtrMut;
|
||||
|
||||
let compute_type = if gemm_reduced_precision_f32() {
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
|
||||
} else {
|
||||
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||
};
|
||||
let alpha = &cfg.gemm.alpha as *const f32 as *const _;
|
||||
let beta = &cfg.gemm.beta as *const f32 as *const _;
|
||||
|
||||
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||
*cublas.handle(),
|
||||
cfg.gemm.transa,
|
||||
cfg.gemm.transb,
|
||||
cfg.gemm.m,
|
||||
cfg.gemm.n,
|
||||
cfg.gemm.k,
|
||||
alpha,
|
||||
*a.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_32F,
|
||||
cfg.gemm.lda,
|
||||
cfg.stride_a,
|
||||
*b.device_ptr() as *const _,
|
||||
sys::cudaDataType_t::CUDA_R_32F,
|
||||
cfg.gemm.ldb,
|
||||
cfg.stride_b,
|
||||
beta,
|
||||
*c.device_ptr_mut() as *mut _,
|
||||
sys::cudaDataType_t::CUDA_R_32F,
|
||||
cfg.gemm.ldc,
|
||||
cfg.stride_c,
|
||||
cfg.batch_size,
|
||||
compute_type,
|
||||
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe fn gemm_strided_batched_f16(
|
||||
cublas: &cudarc::cublas::CudaBlas,
|
||||
cfg: StridedBatchedConfig<f16>,
|
||||
|
@ -258,13 +258,3 @@ pub fn gemm_reduced_precision_bf16() -> bool {
|
||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||
/// allowed with bf16 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||
/// allowed with f32 GEMMs.
|
||||
pub fn gemm_reduced_precision_f32() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||
/// allowed with f32 GEMMs.
|
||||
pub fn set_gemm_reduced_precision_f32(_b: bool) {}
|
||||
|
@ -100,11 +100,11 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
|
||||
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
|
||||
let mut command_buffer = command_buffer_lock.to_owned();
|
||||
let mut index = self
|
||||
.command_buffer_index
|
||||
.write()
|
||||
.try_write()
|
||||
.map_err(MetalError::from)?;
|
||||
if *index > self.compute_per_buffer {
|
||||
command_buffer.commit();
|
||||
@ -119,7 +119,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn wait_until_completed(&self) -> Result<()> {
|
||||
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
|
||||
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
|
||||
match command_buffer.status() {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
@ -179,7 +179,7 @@ impl MetalDevice {
|
||||
size,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
let subbuffers = buffers
|
||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||
.or_insert(vec![]);
|
||||
@ -232,7 +232,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
fn drop_unused_buffers(&self) -> Result<()> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
@ -251,7 +251,7 @@ impl MetalDevice {
|
||||
option: MTLResourceOptions,
|
||||
_name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||
// Cloning also ensures we increment the strong count
|
||||
return Ok(b.clone());
|
||||
|
@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
|
||||
mod device;
|
||||
pub use device::{DeviceId, MetalDevice};
|
||||
@ -36,12 +36,6 @@ impl<T> From<TryLockError<T>> for MetalError {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<PoisonError<T>> for MetalError {
|
||||
fn from(p: PoisonError<T>) -> Self {
|
||||
MetalError::LockError(LockError::Poisoned(p.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum MetalError {
|
||||
|
@ -349,30 +349,6 @@ impl MmapedSafetensors {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SliceSafetensors<'a> {
|
||||
safetensors: SafeTensors<'a>,
|
||||
}
|
||||
|
||||
impl<'a> SliceSafetensors<'a> {
|
||||
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||
pub fn new(buffer: &'a [u8]) -> Result<Self> {
|
||||
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
|
||||
Ok(Self { safetensors })
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.safetensors.tensor(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
self.safetensors.tensors()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
Ok(self.safetensors.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BufferedSafetensors {
|
||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||
}
|
||||
|
@ -235,66 +235,4 @@ impl Tensor {
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Set the values on `self` using values from `src`. The copy starts at the specified
|
||||
/// `offset` for the target dimension `dim` on `self`.
|
||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||
///
|
||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||
/// back-propagation.
|
||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||
if !self.is_contiguous() || !src.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
|
||||
}
|
||||
if self.dtype() != src.dtype() {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: src.dtype(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.device().location() != src.device().location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: self.device().location(),
|
||||
rhs: src.device().location(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.rank() != src.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: self.rank(),
|
||||
got: src.rank(),
|
||||
shape: self.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
|
||||
if dim_idx == dim && *v2 + offset > *v1 {
|
||||
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
|
||||
}
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
|
||||
}
|
||||
}
|
||||
let block_size: usize = src.dims().iter().skip(1 + dim).product();
|
||||
let d1: usize = src.dims().iter().take(dim).product();
|
||||
let d2 = block_size * src.dims()[dim];
|
||||
let dst_o = self.layout().start_offset() + offset * block_size;
|
||||
let src_o = src.layout().start_offset();
|
||||
src.storage().copy2d(
|
||||
&mut self.storage_mut(),
|
||||
d1,
|
||||
d2,
|
||||
/* src_s */ d2,
|
||||
/* dst_s */ block_size * self.dims()[dim],
|
||||
src_o,
|
||||
dst_o,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -1,31 +1,5 @@
|
||||
use candle_core::{DType, Result, Tensor};
|
||||
|
||||
struct TmpFile(std::path::PathBuf);
|
||||
|
||||
impl TmpFile {
|
||||
fn create(base: &str) -> TmpFile {
|
||||
let filename = std::env::temp_dir().join(format!(
|
||||
"candle-{}-{}-{:?}",
|
||||
base,
|
||||
std::process::id(),
|
||||
std::thread::current().id(),
|
||||
));
|
||||
TmpFile(filename)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::AsRef<std::path::Path> for TmpFile {
|
||||
fn as_ref(&self) -> &std::path::Path {
|
||||
self.0.as_path()
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TmpFile {
|
||||
fn drop(&mut self) {
|
||||
std::fs::remove_file(&self.0).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn npy() -> Result<()> {
|
||||
let npy = Tensor::read_npy("tests/test.npy")?;
|
||||
@ -48,24 +22,3 @@ fn npz() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn safetensors() -> Result<()> {
|
||||
use candle_core::safetensors::Load;
|
||||
|
||||
let tmp_file = TmpFile::create("st");
|
||||
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
|
||||
t.save_safetensors("t", &tmp_file)?;
|
||||
// Load from file.
|
||||
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
|
||||
let t2 = st.get("t").unwrap();
|
||||
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0f32);
|
||||
// Load from bytes.
|
||||
let bytes = std::fs::read(tmp_file)?;
|
||||
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
|
||||
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
|
||||
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0f32);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -665,30 +665,6 @@ fn broadcast(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_set(device: &Device) -> Result<()> {
|
||||
let (b, h, max_t, d) = (2, 4, 7, 3);
|
||||
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
||||
cache.slice_set(&tensor, 2, 0)?;
|
||||
let cache_t = cache.narrow(2, 0, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
cache.slice_set(&tensor, 2, 1)?;
|
||||
let cache_t = cache.narrow(2, 1, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
||||
cache.slice_set(&ones, 2, 6)?;
|
||||
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cat(device: &Device) -> Result<()> {
|
||||
// 1D
|
||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
@ -1170,7 +1146,6 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
|
||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||
|
@ -193,9 +193,6 @@ struct Args {
|
||||
/// The model to use.
|
||||
#[arg(long, default_value = "2b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -273,7 +270,7 @@ fn main() -> Result<()> {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(args.use_flash_attn, &config, vb)?;
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
|
@ -1,19 +0,0 @@
|
||||
# gte-Qwen1.5-7B-instruct
|
||||
|
||||
gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.
|
||||
|
||||
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
|
||||
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
|
||||
|
||||
|
||||
## Running the example
|
||||
|
||||
Automatically download the model from the HuggingFace hub:
|
||||
```bash
|
||||
$ cargo run --example gte-qwen --release
|
||||
```
|
||||
|
||||
or, load the model from a local directory:
|
||||
```bash
|
||||
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
|
||||
```
|
@ -1,178 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::qwen2::{Config, Model};
|
||||
|
||||
use candle::{DType, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::{
|
||||
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
|
||||
Tokenizer,
|
||||
};
|
||||
|
||||
// gte-Qwen1.5-7B-instruct use EOS token as padding token
|
||||
const EOS_TOKEN: &str = "<|endoftext|>";
|
||||
const EOS_TOKEN_ID: u32 = 151643;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
local_repo: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConfigFiles {
|
||||
pub config: std::path::PathBuf,
|
||||
pub tokenizer: std::path::PathBuf,
|
||||
pub weights: Vec<std::path::PathBuf>,
|
||||
}
|
||||
|
||||
// Loading the model from the HuggingFace Hub. Network access is required.
|
||||
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id.to_string(),
|
||||
RepoType::Model,
|
||||
revision.to_string(),
|
||||
));
|
||||
Ok(ConfigFiles {
|
||||
config: repo.get("config.json")?,
|
||||
tokenizer: repo.get("tokenizer.json")?,
|
||||
weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
})
|
||||
}
|
||||
|
||||
// Loading the model from a local directory.
|
||||
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
|
||||
let local_path = std::path::PathBuf::from(local_path);
|
||||
let weight_path = local_path.join("model.safetensors.index.json");
|
||||
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
|
||||
let weight_map = match json.get("weight_map") {
|
||||
Some(serde_json::Value::Object(map)) => map,
|
||||
Some(_) => panic!("`weight map` is not a map"),
|
||||
None => panic!("`weight map` not found"),
|
||||
};
|
||||
let mut safetensors_files = std::collections::HashSet::new();
|
||||
for value in weight_map.values() {
|
||||
safetensors_files.insert(
|
||||
value
|
||||
.as_str()
|
||||
.expect("Weight files should be parsed as strings"),
|
||||
);
|
||||
}
|
||||
let safetensors_paths = safetensors_files
|
||||
.iter()
|
||||
.map(|v| local_path.join(v))
|
||||
.collect::<Vec<_>>();
|
||||
Ok(ConfigFiles {
|
||||
config: local_path.join("config.json"),
|
||||
tokenizer: local_path.join("tokenizer.json"),
|
||||
weights: safetensors_paths,
|
||||
})
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Fetch the model. Do this offline if local path provided.
|
||||
println!("Fetching model files...");
|
||||
let start = std::time::Instant::now();
|
||||
let config_files = match args.local_repo {
|
||||
Some(local_path) => load_from_local(&local_path)?,
|
||||
None => load_from_hub(&args.model_id, &args.revision)?,
|
||||
};
|
||||
println!("Model file retrieved in {:?}", start.elapsed());
|
||||
|
||||
// Inputs will be padded to the longest sequence in the batch.
|
||||
let padding = PaddingParams {
|
||||
strategy: PaddingStrategy::BatchLongest,
|
||||
direction: PaddingDirection::Left,
|
||||
pad_to_multiple_of: None,
|
||||
pad_id: EOS_TOKEN_ID,
|
||||
pad_type_id: 0,
|
||||
pad_token: String::from(EOS_TOKEN),
|
||||
};
|
||||
|
||||
// Tokenizer setup
|
||||
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
|
||||
tokenizer.with_padding(Some(padding));
|
||||
|
||||
// Model initialization
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
|
||||
let mut model = Model::new(&config, vb)?;
|
||||
println!("Model loaded in {:?}", start.elapsed());
|
||||
|
||||
// Encode the queries and the targets
|
||||
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
|
||||
let documents = vec![
|
||||
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
|
||||
format!("{instruct}summit define{EOS_TOKEN}"),
|
||||
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
|
||||
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
|
||||
];
|
||||
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
|
||||
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
|
||||
let tokens = Tensor::new(tokens, &device)?;
|
||||
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
|
||||
let mask = Tensor::new(mask, &device)?;
|
||||
|
||||
// Inference
|
||||
let start_gen = std::time::Instant::now();
|
||||
let logits = model.forward(&tokens, 0, Some(&mask))?;
|
||||
|
||||
// Extract the last hidden states as embeddings since inputs are padded left.
|
||||
let (_, seq_len, _) = logits.dims3()?;
|
||||
let embd = logits
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.squeeze(1)?
|
||||
.to_dtype(DType::F32)?;
|
||||
|
||||
// Calculate the relativity scores. Note the embeddings should be normalized.
|
||||
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
|
||||
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;
|
||||
|
||||
// Print the results
|
||||
println!("Embedding done in {:?}", start_gen.elapsed());
|
||||
println!("Scores: {:?}", scores.to_vec2::<f32>()?);
|
||||
|
||||
Ok(())
|
||||
}
|
@ -90,9 +90,6 @@ struct Args {
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "phi-3b")]
|
||||
which: Which,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@ -216,13 +213,7 @@ fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
match args.which {
|
||||
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
|
||||
1,
|
||||
args.use_flash_attn,
|
||||
model,
|
||||
&mut file,
|
||||
&device,
|
||||
)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
|
||||
}
|
||||
};
|
||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
||||
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
|
||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
|
@ -39,7 +39,7 @@ struct Args {
|
||||
|
||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||
/// mask, positive makes the mask more selective.
|
||||
#[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
|
||||
#[arg(long, default_value_t = 0.)]
|
||||
threshold: f32,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
|
@ -42,10 +42,6 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||
}
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
|
@ -139,9 +139,7 @@ impl FlashAttn {
|
||||
|
||||
let elem_count = out_shape.elem_count();
|
||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||
let softmax_lse = dev
|
||||
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
|
||||
.w()?;
|
||||
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
|
||||
|
||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||
|
||||
|
@ -1,101 +0,0 @@
|
||||
use candle::{DType, Device, Result, Shape, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
all_data: Tensor,
|
||||
dim: usize,
|
||||
current_seq_len: usize,
|
||||
max_seq_len: usize,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||
dim: D,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||
let max_seq_len = shape.dims()[dim];
|
||||
let all_data = Tensor::zeros(shape, dtype, dev)?;
|
||||
Ok(Self {
|
||||
all_data,
|
||||
dim,
|
||||
current_seq_len: 0,
|
||||
max_seq_len,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
|
||||
pub fn current_seq_len(&self) -> usize {
|
||||
self.current_seq_len
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
pub fn all_data(&self) -> &Tensor {
|
||||
&self.all_data
|
||||
}
|
||||
|
||||
pub fn current_data(&self) -> Result<Tensor> {
|
||||
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
||||
}
|
||||
|
||||
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||
let seq_len = src.dim(self.dim)?;
|
||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||
candle::bail!(
|
||||
"kv-cache: above max-seq-len {}+{seq_len}>{}",
|
||||
self.current_seq_len,
|
||||
self.max_seq_len
|
||||
)
|
||||
}
|
||||
self.all_data
|
||||
.slice_set(src, self.dim, self.current_seq_len)?;
|
||||
self.current_seq_len += seq_len;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KvCache {
|
||||
k: Cache,
|
||||
v: Cache,
|
||||
}
|
||||
|
||||
impl KvCache {
|
||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||
dim: D,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||
let k = Cache::new(dim, &shape, dtype, dev)?;
|
||||
let v = Cache::new(dim, &shape, dtype, dev)?;
|
||||
Ok(Self { k, v })
|
||||
}
|
||||
|
||||
pub fn k(&self) -> Result<Tensor> {
|
||||
self.k.current_data()
|
||||
}
|
||||
|
||||
pub fn v(&self) -> Result<Tensor> {
|
||||
self.v.current_data()
|
||||
}
|
||||
|
||||
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
self.k.append(k)?;
|
||||
self.v.append(v)?;
|
||||
let k = self.k.current_data()?;
|
||||
let v = self.v.current_data()?;
|
||||
Ok((k, v))
|
||||
}
|
||||
}
|
@ -6,7 +6,6 @@ pub mod encoding;
|
||||
pub mod func;
|
||||
pub mod group_norm;
|
||||
pub mod init;
|
||||
pub mod kv_cache;
|
||||
pub mod layer_norm;
|
||||
pub mod linear;
|
||||
pub mod loss;
|
||||
|
@ -422,32 +422,6 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
|
||||
fn get(
|
||||
&self,
|
||||
s: Shape,
|
||||
name: &str,
|
||||
_: crate::Init,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
|
||||
if tensor.shape() != &s {
|
||||
Err(candle::Error::UnexpectedShape {
|
||||
msg: format!("shape mismatch for {name}"),
|
||||
expected: s,
|
||||
got: tensor.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.get(name).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
/// Initializes a `VarBuilder` using a custom backend.
|
||||
///
|
||||
@ -507,18 +481,12 @@ impl<'a> VarBuilder<'a> {
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
|
||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
||||
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` from a binary slice in the safetensor format.
|
||||
pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let tensors = candle::safetensors::SliceSafetensors::new(data)?;
|
||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let npz = candle::npy::NpzTensors::new(p)?;
|
||||
|
@ -971,7 +971,7 @@ pub fn simple_eval(
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
random_type @ ("RandomUniform" | "RandomNormal") => {
|
||||
"RandomUniform" => {
|
||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||
// type by
|
||||
// default
|
||||
@ -979,42 +979,36 @@ pub fn simple_eval(
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(DType::U8 | DType::U32 | DType::I64) => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
|
||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
};
|
||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
||||
if seed.is_some() {
|
||||
bail!("seed for {random_type} is currently not supported")
|
||||
bail!("seed for RandomUniform is currently not supported")
|
||||
};
|
||||
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
||||
.iter()
|
||||
.map(|x| *x as usize)
|
||||
.collect();
|
||||
let output = if random_type == "RandomUniform" {
|
||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||
Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||
} else {
|
||||
let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
|
||||
let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
|
||||
Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||
};
|
||||
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
|
@ -2020,150 +2020,6 @@ fn test_random_uniform() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "RandomNormal"
|
||||
#[test]
|
||||
fn test_random_normal() -> Result<()> {
|
||||
test(vec![3, 2, 1, 4], None, None)?;
|
||||
test(vec![2, 2, 2, 2], Some(-10.0), None)?;
|
||||
test(vec![2, 2, 2, 2], None, Some(10.0))?;
|
||||
test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
|
||||
|
||||
fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
|
||||
let att_mean = AttributeProto {
|
||||
name: "mean".to_string(),
|
||||
ref_attr_name: "mean".to_string(),
|
||||
i: 0,
|
||||
doc_string: "mean".to_string(),
|
||||
r#type: 1, // FLOAT
|
||||
f: mean.unwrap_or(0.0),
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let att_scale = AttributeProto {
|
||||
name: "scale".to_string(),
|
||||
ref_attr_name: "scale".to_string(),
|
||||
i: 0,
|
||||
doc_string: "scale".to_string(),
|
||||
r#type: 1, // FLOAT
|
||||
f: scale.unwrap_or(1.0),
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let att_shape = AttributeProto {
|
||||
name: "shape".to_string(),
|
||||
ref_attr_name: "shape".to_string(),
|
||||
i: 0,
|
||||
doc_string: "shape".to_string(),
|
||||
r#type: 7, // INTS
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: shape,
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let att_dtype = AttributeProto {
|
||||
name: "dtype".to_string(),
|
||||
ref_attr_name: "dtype".to_string(),
|
||||
i: 11, // DOUBLE
|
||||
doc_string: "dtype".to_string(),
|
||||
r#type: 2, // INT
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let attrs = {
|
||||
let mut mut_attrs = vec![att_shape, att_dtype];
|
||||
if mean.is_some() {
|
||||
mut_attrs.push(att_mean);
|
||||
}
|
||||
if scale.is_some() {
|
||||
mut_attrs.push(att_scale);
|
||||
}
|
||||
mut_attrs
|
||||
};
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "RandomNormal".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: attrs,
|
||||
input: vec![],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let data = z.flatten_all()?.to_vec1::<f64>()?;
|
||||
|
||||
// test if values are unique
|
||||
for (i, a) in data.iter().enumerate() {
|
||||
for (j, b) in data.iter().enumerate() {
|
||||
if i == j {
|
||||
continue;
|
||||
};
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Range"
|
||||
#[test]
|
||||
fn test_range() -> Result<()> {
|
||||
|
@ -1,4 +1,4 @@
|
||||
use candle::{DType, Error, Result, Tensor};
|
||||
use candle::{DType, Error, IndexOp, Result, Tensor, D};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
#[derive(Clone, PartialEq, Debug)]
|
||||
@ -73,17 +73,15 @@ impl LogitsProcessor {
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
fn sample_topk(&mut self, prs: &mut Vec<f32>, top_k: usize) -> Result<u32> {
|
||||
if top_k >= prs.len() {
|
||||
self.sample_multinomial(prs)
|
||||
} else {
|
||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||
let (indices, _, _) =
|
||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||
let index = self.sample_multinomial(&prs)?;
|
||||
Ok(indices[index as usize] as u32)
|
||||
}
|
||||
fn sample_topk(&mut self, logits: &Tensor, top_k: usize, temperature: f64) -> Result<u32> {
|
||||
let arg_sort = logits.arg_sort_last_dim(false)?;
|
||||
let top_k_indices = arg_sort.narrow(candle::D::Minus1, 0, top_k)?;
|
||||
let top_k_logits = logits.gather(&top_k_indices, D::Minus1)?;
|
||||
let top_k_logits = (&top_k_logits / temperature)?;
|
||||
let top_k_prs = candle_nn::ops::softmax_last_dim(&top_k_logits)?;
|
||||
let top_k_prs = top_k_prs.to_vec1()?;
|
||||
let index = self.sample_multinomial(&top_k_prs)?;
|
||||
Ok(top_k_indices.i(index as usize)?.to_vec0::<u32>()?)
|
||||
}
|
||||
|
||||
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||
@ -137,8 +135,12 @@ impl LogitsProcessor {
|
||||
}
|
||||
}
|
||||
Sampling::TopK { k, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
self.sample_topk(&mut prs, *k)?
|
||||
if *k >= logits.dim(D::Minus1)? {
|
||||
let prs = prs(*temperature)?;
|
||||
self.sample_multinomial(&prs)?
|
||||
} else {
|
||||
self.sample_topk(&logits, *k, *temperature)?
|
||||
}
|
||||
}
|
||||
Sampling::TopKThenTopP { k, p, temperature } => {
|
||||
let mut prs = prs(*temperature)?;
|
||||
|
@ -73,6 +73,13 @@ struct RotaryEmbedding {
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
||||
let last_dim = xs.dim(D::Minus1)?;
|
||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||
let dim = cfg.head_dim;
|
||||
@ -87,6 +94,7 @@ impl RotaryEmbedding {
|
||||
.to_dtype(dtype)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
@ -102,8 +110,10 @@ impl RotaryEmbedding {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
@ -153,16 +163,10 @@ struct Attention {
|
||||
head_dim: usize,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
use_flash_attn: bool,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let hidden_sz = cfg.hidden_size;
|
||||
let num_heads = cfg.num_attention_heads;
|
||||
let num_kv_heads = cfg.num_key_value_heads;
|
||||
@ -184,7 +188,6 @@ impl Attention {
|
||||
head_dim,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
use_flash_attn,
|
||||
})
|
||||
}
|
||||
|
||||
@ -228,14 +231,7 @@ impl Attention {
|
||||
let value_states =
|
||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||
|
||||
let attn_output = if self.use_flash_attn {
|
||||
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||
let q = query_states.transpose(1, 2)?;
|
||||
let k = key_states.transpose(1, 2)?;
|
||||
let v = value_states.transpose(1, 2)?;
|
||||
let scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
|
||||
} else {
|
||||
let attn_output = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||
|
||||
@ -257,22 +253,6 @@ impl Attention {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct DecoderLayer {
|
||||
self_attn: Attention,
|
||||
@ -282,13 +262,8 @@ struct DecoderLayer {
|
||||
}
|
||||
|
||||
impl DecoderLayer {
|
||||
fn new(
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
use_flash_attn: bool,
|
||||
cfg: &Config,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
|
||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
let input_layernorm =
|
||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||
@ -337,7 +312,7 @@ pub struct Model {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let vb_m = vb.pp("model");
|
||||
let embed_tokens =
|
||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||
@ -345,8 +320,7 @@ impl Model {
|
||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||
let vb_l = vb_m.pp("layers");
|
||||
for layer_idx in 0..cfg.num_hidden_layers {
|
||||
let layer =
|
||||
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
|
||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
|
@ -3,7 +3,9 @@ use std::collections::HashMap;
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
|
||||
use candle_nn::{Embedding, RmsNorm};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QLinear {
|
||||
@ -68,8 +70,7 @@ struct LayerWeights {
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
neg_inf: Tensor,
|
||||
kv_cache: KvCache,
|
||||
use_flash_attn: bool,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
}
|
||||
@ -121,55 +122,40 @@ impl LayerWeights {
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k.contiguous()?, v.contiguous()?),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
|
||||
let y = if self.use_flash_attn {
|
||||
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||
let q = q.to_dtype(DType::BF16)?.transpose(1, 2)?;
|
||||
let k = k.to_dtype(DType::BF16)?.transpose(1, 2)?;
|
||||
let v = v.to_dtype(DType::BF16)?.transpose(1, 2)?;
|
||||
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||
flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||
.to_dtype(DType::F32)?
|
||||
.transpose(1, 2)?
|
||||
} else {
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let att = match mask {
|
||||
None => att,
|
||||
Some(mask) => {
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
masked_fill(&att, &mask, &self.neg_inf)?
|
||||
}
|
||||
};
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
att.matmul(&v.contiguous()?)?
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let att = match mask {
|
||||
None => att,
|
||||
Some(mask) => {
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
masked_fill(&att, &mask, &self.neg_inf)?
|
||||
}
|
||||
};
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
let y = self.attn_output.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "flash-attn")]
|
||||
fn flash_attn(
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
v: &Tensor,
|
||||
softmax_scale: f32,
|
||||
causal: bool,
|
||||
) -> Result<Tensor> {
|
||||
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "flash-attn"))]
|
||||
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||
unimplemented!("compile with '--features flash-attn'")
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelWeights {
|
||||
tok_embeddings: Embedding,
|
||||
@ -183,7 +169,6 @@ pub struct ModelWeights {
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
max_seq_len: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
@ -192,9 +177,9 @@ fn precomput_freqs_cis(
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
@ -203,8 +188,6 @@ fn precomput_freqs_cis(
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
batch_size: usize,
|
||||
use_flash_attn: bool,
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
@ -219,19 +202,16 @@ impl ModelWeights {
|
||||
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
|
||||
let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
|
||||
let head_dim = embedding_length / head_count;
|
||||
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
|
||||
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
|
||||
let output = QLinear::new(&ct, reader, "output", device)?;
|
||||
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
@ -252,12 +232,6 @@ impl ModelWeights {
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let kv_cache = KvCache::new(
|
||||
2,
|
||||
(batch_size, head_count_kv, max_seq_len, head_dim),
|
||||
DType::F32,
|
||||
device,
|
||||
)?;
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
@ -266,12 +240,11 @@ impl ModelWeights {
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim,
|
||||
head_dim: embedding_length / head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache,
|
||||
use_flash_attn,
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
})
|
||||
|
@ -1,5 +1,5 @@
|
||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -250,6 +250,7 @@ pub struct Model {
|
||||
embed_tokens: candle_nn::Embedding,
|
||||
layers: Vec<DecoderLayer>,
|
||||
norm: RmsNorm,
|
||||
lm_head: Linear,
|
||||
sliding_window: usize,
|
||||
device: Device,
|
||||
dtype: DType,
|
||||
@ -268,17 +269,19 @@ impl Model {
|
||||
layers.push(layer)
|
||||
}
|
||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
embed_tokens,
|
||||
layers,
|
||||
norm,
|
||||
lm_head,
|
||||
sliding_window: cfg.sliding_window,
|
||||
device: vb.device().clone(),
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
fn prepare_causal_attention_mask(
|
||||
fn prepare_decoder_attention_mask(
|
||||
&self,
|
||||
b_size: usize,
|
||||
tgt_len: usize,
|
||||
@ -298,7 +301,7 @@ impl Model {
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||
let mask = if seqlen_offset > 0 {
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
|
||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||
} else {
|
||||
mask
|
||||
@ -307,42 +310,21 @@ impl Model {
|
||||
.to_dtype(self.dtype)
|
||||
}
|
||||
|
||||
fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, sql_len) = attn_mask.dims2()?;
|
||||
let mut mask: Vec<Tensor> = vec![];
|
||||
for b in 0..b_sz {
|
||||
mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
|
||||
}
|
||||
let mask = Tensor::cat(&mask, 0)?;
|
||||
let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
|
||||
let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
|
||||
.broadcast_as(mask.shape())?
|
||||
.to_dtype(self.dtype)?;
|
||||
mask.where_cond(&on_true, &on_false)
|
||||
}
|
||||
|
||||
pub fn forward(
|
||||
&mut self,
|
||||
input_ids: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
attn_mask: Option<&Tensor>,
|
||||
) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (b_size, seq_len) = input_ids.dims2()?;
|
||||
let attention_mask: Option<Tensor> = match attn_mask {
|
||||
Some(mask) => Some(self.prepare_attention_mask(mask)?),
|
||||
None => {
|
||||
if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
|
||||
}
|
||||
}
|
||||
let attention_mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||
Some(mask)
|
||||
};
|
||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||
}
|
||||
xs.apply(&self.norm)
|
||||
xs.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.norm)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
@ -351,32 +333,3 @@ impl Model {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelForCausalLM {
|
||||
base_model: Model,
|
||||
lm_head: Linear,
|
||||
}
|
||||
|
||||
impl ModelForCausalLM {
|
||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||
let base_model = Model::new(cfg, vb)?;
|
||||
Ok(Self {
|
||||
base_model,
|
||||
lm_head,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = input_ids.dims2()?;
|
||||
self.base_model
|
||||
.forward(input_ids, seqlen_offset, None)?
|
||||
.narrow(1, seq_len - 1, 1)?
|
||||
.apply(&self.lm_head)
|
||||
}
|
||||
|
||||
pub fn clear_kv_cache(&mut self) {
|
||||
self.base_model.clear_kv_cache()
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ log = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
hound = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
|
||||
# Wasm specific crates.
|
||||
|
@ -345,19 +345,16 @@ impl Decoder {
|
||||
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||
let device = Device::Cpu;
|
||||
let mut wav_input = std::io::Cursor::new(wav_input);
|
||||
let wav_reader = hound::WavReader::new(&mut wav_input)?;
|
||||
let spec = wav_reader.spec();
|
||||
console_log!("loaded wav data: {spec:?}");
|
||||
if spec.sample_rate != m::SAMPLE_RATE as u32 {
|
||||
let (header, data) = wav::read(&mut wav_input)?;
|
||||
console_log!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
|
||||
}
|
||||
let mut data = wav_reader.into_samples::<i16>().collect::<Vec<_>>();
|
||||
data.truncate(data.len() / spec.channels as usize);
|
||||
let mut pcm_data = Vec::with_capacity(data.len());
|
||||
for d in data.into_iter() {
|
||||
let d = d?;
|
||||
pcm_data.push(d as f32 / 32768.)
|
||||
}
|
||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||
.iter()
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
console_log!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
|
||||
let mel_len = mel.len();
|
||||
|
Reference in New Issue
Block a user