mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Compare commits
15 Commits
improve-sa
...
0.5.1
Author | SHA1 | Date | |
---|---|---|---|
7ff921c538 | |||
9b8537a62f | |||
7ebc3548e1 | |||
eefc1c77ef | |||
01545f7303 | |||
349c3e806a | |||
bdaa34216a | |||
cc80e065e5 | |||
13c64f6828 | |||
21f82a5155 | |||
9cff7bc3f4 | |||
d9bc5ec151 | |||
84328e2b60 | |||
82b641fd27 | |||
01794dc16e |
@ -43,11 +43,12 @@ candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
|
|||||||
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
|
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.10.0", features = ["f16"] }
|
cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
fancy-regex = "0.13.0"
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.3.0"
|
hf-hub = "0.3.0"
|
||||||
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||||
|
hound = "3.5.1"
|
||||||
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
|
image = { version = "0.25.0", default-features = false, features = ["jpeg", "png"] }
|
||||||
imageproc = { version = "0.24.0", default-features = false }
|
imageproc = { version = "0.24.0", default-features = false }
|
||||||
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
|
||||||
@ -69,7 +70,6 @@ tokenizers = { version = "0.19.1", default-features = false }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
wav = "1.0.0"
|
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "1.1.1", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
@ -408,3 +408,10 @@ This may be caused by the models being loaded from `/mnt/c`, more details on
|
|||||||
|
|
||||||
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
You can set `RUST_BACKTRACE=1` to be provided with backtraces when a candle
|
||||||
error is generated.
|
error is generated.
|
||||||
|
|
||||||
|
#### CudaRC error
|
||||||
|
|
||||||
|
If you encounter an error like this one `called `Result::unwrap()` on an `Err` value: LoadLibraryExW { source: Os { code: 126, kind: Uncategorized, message: "The specified module could not be found." } }` on windows. To fix copy and rename these 3 files (make sure they are in path). The paths depend on your cuda version.
|
||||||
|
`c:\Windows\System32\nvcuda.dll` -> `cuda.dll`
|
||||||
|
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\cublas64_12.dll` -> `cublas.dll`
|
||||||
|
`c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin\curand64_10.dll` -> `curand.dll`
|
||||||
|
@ -37,7 +37,6 @@ tokenizers = { workspace = true, features = ["onig"] }
|
|||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-chrome = { workspace = true }
|
tracing-chrome = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
wav = { workspace = true }
|
|
||||||
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
|
||||||
parquet = { workspace = true }
|
parquet = { workspace = true }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
|
@ -5,32 +5,26 @@ extern crate accelerate_src;
|
|||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Module, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
use candle_core::quantized::{QMatMul, QTensor};
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
|
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?;
|
||||||
let q_cpu = q.to_device(&Device::Cpu)?;
|
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||||
let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
let q = QMatMul::from_qtensor(q)?;
|
drop(_x1);
|
||||||
let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
|
let start_time = std::time::Instant::now();
|
||||||
let res_q_cuda = q.forward(&x)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
println!("{res_q_cuda}");
|
device.synchronize()?;
|
||||||
|
println!("fp32: {:?}", start_time.elapsed());
|
||||||
let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
|
drop(_x1);
|
||||||
let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
|
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
||||||
let q_cpu = QMatMul::from_qtensor(q_cpu)?;
|
let _x1 = x.matmul(&x)?;
|
||||||
let x_cpu = x.to_device(&Device::Cpu)?;
|
drop(_x1);
|
||||||
let res_q_cpu = q_cpu.forward(&x_cpu)?;
|
let start_time = std::time::Instant::now();
|
||||||
println!("{res_q_cpu}");
|
let _x1 = x.matmul(&x)?;
|
||||||
|
device.synchronize()?;
|
||||||
let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
|
println!("tf32: {:?}", start_time.elapsed());
|
||||||
let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
|
drop(_x1);
|
||||||
.abs()?
|
|
||||||
.flatten_all()?
|
|
||||||
.max(0)?;
|
|
||||||
println!("{diff}");
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1615,11 +1615,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||||
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
unsafe {
|
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||||
self.device
|
|
||||||
.blas
|
|
||||||
.gemm_strided_batched(cfg, rhs, lhs, &mut out)
|
|
||||||
}
|
|
||||||
.w()?;
|
.w()?;
|
||||||
CudaStorageSlice::F32(out)
|
CudaStorageSlice::F32(out)
|
||||||
}
|
}
|
||||||
@ -1817,6 +1813,20 @@ static MM_F16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
|||||||
std::sync::atomic::AtomicBool::new(false);
|
std::sync::atomic::AtomicBool::new(false);
|
||||||
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
static MM_BF16_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||||
std::sync::atomic::AtomicBool::new(false);
|
std::sync::atomic::AtomicBool::new(false);
|
||||||
|
static MM_F32_REDUCED_PRECISION: std::sync::atomic::AtomicBool =
|
||||||
|
std::sync::atomic::AtomicBool::new(false);
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_f32() -> bool {
|
||||||
|
MM_F32_REDUCED_PRECISION.load(std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_f32(b: bool) {
|
||||||
|
MM_F32_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
/// allowed with f16 GEMMs.
|
/// allowed with f16 GEMMs.
|
||||||
@ -1842,6 +1852,51 @@ pub fn set_gemm_reduced_precision_bf16(b: bool) {
|
|||||||
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
MM_BF16_REDUCED_PRECISION.store(b, std::sync::atomic::Ordering::Relaxed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe fn gemm_strided_batched_f32(
|
||||||
|
cublas: &cudarc::cublas::CudaBlas,
|
||||||
|
cfg: StridedBatchedConfig<f32>,
|
||||||
|
a: &cudarc::driver::CudaView<f32>,
|
||||||
|
b: &cudarc::driver::CudaView<f32>,
|
||||||
|
c: &mut CudaSlice<f32>,
|
||||||
|
) -> std::result::Result<(), cudarc::cublas::result::CublasError> {
|
||||||
|
use cudarc::cublas::sys;
|
||||||
|
use cudarc::driver::DevicePtrMut;
|
||||||
|
|
||||||
|
let compute_type = if gemm_reduced_precision_f32() {
|
||||||
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_TF32
|
||||||
|
} else {
|
||||||
|
sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
|
||||||
|
};
|
||||||
|
let alpha = &cfg.gemm.alpha as *const f32 as *const _;
|
||||||
|
let beta = &cfg.gemm.beta as *const f32 as *const _;
|
||||||
|
|
||||||
|
cudarc::cublas::result::gemm_strided_batched_ex(
|
||||||
|
*cublas.handle(),
|
||||||
|
cfg.gemm.transa,
|
||||||
|
cfg.gemm.transb,
|
||||||
|
cfg.gemm.m,
|
||||||
|
cfg.gemm.n,
|
||||||
|
cfg.gemm.k,
|
||||||
|
alpha,
|
||||||
|
*a.device_ptr() as *const _,
|
||||||
|
sys::cudaDataType_t::CUDA_R_32F,
|
||||||
|
cfg.gemm.lda,
|
||||||
|
cfg.stride_a,
|
||||||
|
*b.device_ptr() as *const _,
|
||||||
|
sys::cudaDataType_t::CUDA_R_32F,
|
||||||
|
cfg.gemm.ldb,
|
||||||
|
cfg.stride_b,
|
||||||
|
beta,
|
||||||
|
*c.device_ptr_mut() as *mut _,
|
||||||
|
sys::cudaDataType_t::CUDA_R_32F,
|
||||||
|
cfg.gemm.ldc,
|
||||||
|
cfg.stride_c,
|
||||||
|
cfg.batch_size,
|
||||||
|
compute_type,
|
||||||
|
sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
unsafe fn gemm_strided_batched_f16(
|
unsafe fn gemm_strided_batched_f16(
|
||||||
cublas: &cudarc::cublas::CudaBlas,
|
cublas: &cudarc::cublas::CudaBlas,
|
||||||
cfg: StridedBatchedConfig<f16>,
|
cfg: StridedBatchedConfig<f16>,
|
||||||
|
@ -258,3 +258,13 @@ pub fn gemm_reduced_precision_bf16() -> bool {
|
|||||||
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
/// This bool controls whether reduced precision reductions (e.g., with fp16 accumulation type) are
|
||||||
/// allowed with bf16 GEMMs.
|
/// allowed with bf16 GEMMs.
|
||||||
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
pub fn set_gemm_reduced_precision_bf16(_: bool) {}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn gemm_reduced_precision_f32() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This bool controls whether reduced precision reductions (e.g., with tf32 accumulation type) are
|
||||||
|
/// allowed with f32 GEMMs.
|
||||||
|
pub fn set_gemm_reduced_precision_f32(_b: bool) {}
|
||||||
|
@ -100,11 +100,11 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
pub fn command_buffer(&self) -> Result<CommandBuffer> {
|
||||||
let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
|
let mut command_buffer_lock = self.command_buffer.write().map_err(MetalError::from)?;
|
||||||
let mut command_buffer = command_buffer_lock.to_owned();
|
let mut command_buffer = command_buffer_lock.to_owned();
|
||||||
let mut index = self
|
let mut index = self
|
||||||
.command_buffer_index
|
.command_buffer_index
|
||||||
.try_write()
|
.write()
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
if *index > self.compute_per_buffer {
|
if *index > self.compute_per_buffer {
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -119,7 +119,7 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn wait_until_completed(&self) -> Result<()> {
|
pub fn wait_until_completed(&self) -> Result<()> {
|
||||||
let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
|
let mut command_buffer = self.command_buffer.write().map_err(MetalError::from)?;
|
||||||
match command_buffer.status() {
|
match command_buffer.status() {
|
||||||
metal::MTLCommandBufferStatus::Committed
|
metal::MTLCommandBufferStatus::Committed
|
||||||
| metal::MTLCommandBufferStatus::Scheduled
|
| metal::MTLCommandBufferStatus::Scheduled
|
||||||
@ -179,7 +179,7 @@ impl MetalDevice {
|
|||||||
size,
|
size,
|
||||||
MTLResourceOptions::StorageModeManaged,
|
MTLResourceOptions::StorageModeManaged,
|
||||||
);
|
);
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
let subbuffers = buffers
|
let subbuffers = buffers
|
||||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||||
.or_insert(vec![]);
|
.or_insert(vec![]);
|
||||||
@ -232,7 +232,7 @@ impl MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn drop_unused_buffers(&self) -> Result<()> {
|
fn drop_unused_buffers(&self) -> Result<()> {
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
for subbuffers in buffers.values_mut() {
|
for subbuffers in buffers.values_mut() {
|
||||||
let newbuffers = subbuffers
|
let newbuffers = subbuffers
|
||||||
.iter()
|
.iter()
|
||||||
@ -251,7 +251,7 @@ impl MetalDevice {
|
|||||||
option: MTLResourceOptions,
|
option: MTLResourceOptions,
|
||||||
_name: &str,
|
_name: &str,
|
||||||
) -> Result<Arc<Buffer>> {
|
) -> Result<Arc<Buffer>> {
|
||||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
let mut buffers = self.buffers.write().map_err(MetalError::from)?;
|
||||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||||
// Cloning also ensures we increment the strong count
|
// Cloning also ensures we increment the strong count
|
||||||
return Ok(b.clone());
|
return Ok(b.clone());
|
||||||
|
@ -6,7 +6,7 @@ use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
|||||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
use std::sync::{Arc, Mutex, PoisonError, RwLock, TryLockError};
|
||||||
|
|
||||||
mod device;
|
mod device;
|
||||||
pub use device::{DeviceId, MetalDevice};
|
pub use device::{DeviceId, MetalDevice};
|
||||||
@ -36,6 +36,12 @@ impl<T> From<TryLockError<T>> for MetalError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> From<PoisonError<T>> for MetalError {
|
||||||
|
fn from(p: PoisonError<T>) -> Self {
|
||||||
|
MetalError::LockError(LockError::Poisoned(p.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum MetalError {
|
pub enum MetalError {
|
||||||
|
@ -349,6 +349,30 @@ impl MmapedSafetensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct SliceSafetensors<'a> {
|
||||||
|
safetensors: SafeTensors<'a>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SliceSafetensors<'a> {
|
||||||
|
/// Creates a wrapper around a binary buffer and deserialize the safetensors header.
|
||||||
|
pub fn new(buffer: &'a [u8]) -> Result<Self> {
|
||||||
|
let safetensors = safetensors::SafeTensors::deserialize(buffer)?;
|
||||||
|
Ok(Self { safetensors })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||||
|
self.safetensors.tensor(name)?.load(dev)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||||
|
self.safetensors.tensors()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||||
|
Ok(self.safetensors.tensor(name)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct BufferedSafetensors {
|
pub struct BufferedSafetensors {
|
||||||
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
safetensors: yoke::Yoke<SafeTensors_<'static>, Vec<u8>>,
|
||||||
}
|
}
|
||||||
|
@ -235,4 +235,66 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the values on `self` using values from `src`. The copy starts at the specified
|
||||||
|
/// `offset` for the target dimension `dim` on `self`.
|
||||||
|
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||||
|
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||||
|
///
|
||||||
|
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||||
|
/// back-propagation.
|
||||||
|
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||||
|
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||||
|
if !self.is_contiguous() || !src.is_contiguous() {
|
||||||
|
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
|
||||||
|
}
|
||||||
|
if self.dtype() != src.dtype() {
|
||||||
|
Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: self.dtype(),
|
||||||
|
rhs: src.dtype(),
|
||||||
|
op: "slice-set",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if self.device().location() != src.device().location() {
|
||||||
|
Err(Error::DeviceMismatchBinaryOp {
|
||||||
|
lhs: self.device().location(),
|
||||||
|
rhs: src.device().location(),
|
||||||
|
op: "slice-set",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
if self.rank() != src.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: self.rank(),
|
||||||
|
got: src.rank(),
|
||||||
|
shape: self.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
|
||||||
|
if dim_idx == dim && *v2 + offset > *v1 {
|
||||||
|
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
|
||||||
|
}
|
||||||
|
if dim_idx != dim && v1 != v2 {
|
||||||
|
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let block_size: usize = src.dims().iter().skip(1 + dim).product();
|
||||||
|
let d1: usize = src.dims().iter().take(dim).product();
|
||||||
|
let d2 = block_size * src.dims()[dim];
|
||||||
|
let dst_o = self.layout().start_offset() + offset * block_size;
|
||||||
|
let src_o = src.layout().start_offset();
|
||||||
|
src.storage().copy2d(
|
||||||
|
&mut self.storage_mut(),
|
||||||
|
d1,
|
||||||
|
d2,
|
||||||
|
/* src_s */ d2,
|
||||||
|
/* dst_s */ block_size * self.dims()[dim],
|
||||||
|
src_o,
|
||||||
|
dst_o,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,31 @@
|
|||||||
use candle_core::{DType, Result, Tensor};
|
use candle_core::{DType, Result, Tensor};
|
||||||
|
|
||||||
|
struct TmpFile(std::path::PathBuf);
|
||||||
|
|
||||||
|
impl TmpFile {
|
||||||
|
fn create(base: &str) -> TmpFile {
|
||||||
|
let filename = std::env::temp_dir().join(format!(
|
||||||
|
"candle-{}-{}-{:?}",
|
||||||
|
base,
|
||||||
|
std::process::id(),
|
||||||
|
std::thread::current().id(),
|
||||||
|
));
|
||||||
|
TmpFile(filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::convert::AsRef<std::path::Path> for TmpFile {
|
||||||
|
fn as_ref(&self) -> &std::path::Path {
|
||||||
|
self.0.as_path()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TmpFile {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
std::fs::remove_file(&self.0).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn npy() -> Result<()> {
|
fn npy() -> Result<()> {
|
||||||
let npy = Tensor::read_npy("tests/test.npy")?;
|
let npy = Tensor::read_npy("tests/test.npy")?;
|
||||||
@ -22,3 +48,24 @@ fn npz() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn safetensors() -> Result<()> {
|
||||||
|
use candle_core::safetensors::Load;
|
||||||
|
|
||||||
|
let tmp_file = TmpFile::create("st");
|
||||||
|
let t = Tensor::arange(0f32, 24f32, &candle_core::Device::Cpu)?;
|
||||||
|
t.save_safetensors("t", &tmp_file)?;
|
||||||
|
// Load from file.
|
||||||
|
let st = candle_core::safetensors::load(&tmp_file, &candle_core::Device::Cpu)?;
|
||||||
|
let t2 = st.get("t").unwrap();
|
||||||
|
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0f32);
|
||||||
|
// Load from bytes.
|
||||||
|
let bytes = std::fs::read(tmp_file)?;
|
||||||
|
let st = candle_core::safetensors::SliceSafetensors::new(&bytes)?;
|
||||||
|
let t2 = st.get("t").unwrap().load(&candle_core::Device::Cpu);
|
||||||
|
let diff = (&t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0f32);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -665,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn slice_set(device: &Device) -> Result<()> {
|
||||||
|
let (b, h, max_t, d) = (2, 4, 7, 3);
|
||||||
|
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
||||||
|
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
||||||
|
cache.slice_set(&tensor, 2, 0)?;
|
||||||
|
let cache_t = cache.narrow(2, 0, 4)?;
|
||||||
|
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
cache.slice_set(&tensor, 2, 1)?;
|
||||||
|
let cache_t = cache.narrow(2, 1, 4)?;
|
||||||
|
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
||||||
|
cache.slice_set(&ones, 2, 6)?;
|
||||||
|
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
||||||
|
.abs()?
|
||||||
|
.sum_all()?
|
||||||
|
.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn cat(device: &Device) -> Result<()> {
|
fn cat(device: &Device) -> Result<()> {
|
||||||
// 1D
|
// 1D
|
||||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||||
@ -1146,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
|||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||||
|
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
|
||||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||||
|
@ -193,6 +193,9 @@ struct Args {
|
|||||||
/// The model to use.
|
/// The model to use.
|
||||||
#[arg(long, default_value = "2b")]
|
#[arg(long, default_value = "2b")]
|
||||||
which: Which,
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
use_flash_attn: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -270,7 +273,7 @@ fn main() -> Result<()> {
|
|||||||
DType::F32
|
DType::F32
|
||||||
};
|
};
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(args.use_flash_attn, &config, vb)?;
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
19
candle-examples/examples/gte-qwen/README.md
Normal file
19
candle-examples/examples/gte-qwen/README.md
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# gte-Qwen1.5-7B-instruct
|
||||||
|
|
||||||
|
gte-Qwen1.5-7B-instruct is a variant of the GTE embedding model family.
|
||||||
|
|
||||||
|
- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
|
||||||
|
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*
|
||||||
|
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
Automatically download the model from the HuggingFace hub:
|
||||||
|
```bash
|
||||||
|
$ cargo run --example gte-qwen --release
|
||||||
|
```
|
||||||
|
|
||||||
|
or, load the model from a local directory:
|
||||||
|
```bash
|
||||||
|
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
|
||||||
|
```
|
178
candle-examples/examples/gte-qwen/main.rs
Normal file
178
candle-examples/examples/gte-qwen/main.rs
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::qwen2::{Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::{
|
||||||
|
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
|
||||||
|
Tokenizer,
|
||||||
|
};
|
||||||
|
|
||||||
|
// gte-Qwen1.5-7B-instruct use EOS token as padding token
|
||||||
|
const EOS_TOKEN: &str = "<|endoftext|>";
|
||||||
|
const EOS_TOKEN_ID: u32 = 151643;
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
|
||||||
|
model_id: String,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
local_repo: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ConfigFiles {
|
||||||
|
pub config: std::path::PathBuf,
|
||||||
|
pub tokenizer: std::path::PathBuf,
|
||||||
|
pub weights: Vec<std::path::PathBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loading the model from the HuggingFace Hub. Network access is required.
|
||||||
|
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
|
||||||
|
let api = Api::new()?;
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id.to_string(),
|
||||||
|
RepoType::Model,
|
||||||
|
revision.to_string(),
|
||||||
|
));
|
||||||
|
Ok(ConfigFiles {
|
||||||
|
config: repo.get("config.json")?,
|
||||||
|
tokenizer: repo.get("tokenizer.json")?,
|
||||||
|
weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loading the model from a local directory.
|
||||||
|
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
|
||||||
|
let local_path = std::path::PathBuf::from(local_path);
|
||||||
|
let weight_path = local_path.join("model.safetensors.index.json");
|
||||||
|
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
|
||||||
|
let weight_map = match json.get("weight_map") {
|
||||||
|
Some(serde_json::Value::Object(map)) => map,
|
||||||
|
Some(_) => panic!("`weight map` is not a map"),
|
||||||
|
None => panic!("`weight map` not found"),
|
||||||
|
};
|
||||||
|
let mut safetensors_files = std::collections::HashSet::new();
|
||||||
|
for value in weight_map.values() {
|
||||||
|
safetensors_files.insert(
|
||||||
|
value
|
||||||
|
.as_str()
|
||||||
|
.expect("Weight files should be parsed as strings"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let safetensors_paths = safetensors_files
|
||||||
|
.iter()
|
||||||
|
.map(|v| local_path.join(v))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
Ok(ConfigFiles {
|
||||||
|
config: local_path.join("config.json"),
|
||||||
|
tokenizer: local_path.join("tokenizer.json"),
|
||||||
|
weights: safetensors_paths,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fetch the model. Do this offline if local path provided.
|
||||||
|
println!("Fetching model files...");
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config_files = match args.local_repo {
|
||||||
|
Some(local_path) => load_from_local(&local_path)?,
|
||||||
|
None => load_from_hub(&args.model_id, &args.revision)?,
|
||||||
|
};
|
||||||
|
println!("Model file retrieved in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
// Inputs will be padded to the longest sequence in the batch.
|
||||||
|
let padding = PaddingParams {
|
||||||
|
strategy: PaddingStrategy::BatchLongest,
|
||||||
|
direction: PaddingDirection::Left,
|
||||||
|
pad_to_multiple_of: None,
|
||||||
|
pad_id: EOS_TOKEN_ID,
|
||||||
|
pad_type_id: 0,
|
||||||
|
pad_token: String::from(EOS_TOKEN),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Tokenizer setup
|
||||||
|
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
|
||||||
|
tokenizer.with_padding(Some(padding));
|
||||||
|
|
||||||
|
// Model initialization
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
|
||||||
|
let mut model = Model::new(&config, vb)?;
|
||||||
|
println!("Model loaded in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
// Encode the queries and the targets
|
||||||
|
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
|
||||||
|
let documents = vec![
|
||||||
|
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
|
||||||
|
format!("{instruct}summit define{EOS_TOKEN}"),
|
||||||
|
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
|
||||||
|
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
|
||||||
|
];
|
||||||
|
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
|
||||||
|
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
|
||||||
|
let tokens = Tensor::new(tokens, &device)?;
|
||||||
|
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
|
||||||
|
let mask = Tensor::new(mask, &device)?;
|
||||||
|
|
||||||
|
// Inference
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
let logits = model.forward(&tokens, 0, Some(&mask))?;
|
||||||
|
|
||||||
|
// Extract the last hidden states as embeddings since inputs are padded left.
|
||||||
|
let (_, seq_len, _) = logits.dims3()?;
|
||||||
|
let embd = logits
|
||||||
|
.narrow(1, seq_len - 1, 1)?
|
||||||
|
.squeeze(1)?
|
||||||
|
.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
|
// Calculate the relativity scores. Note the embeddings should be normalized.
|
||||||
|
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
|
||||||
|
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;
|
||||||
|
|
||||||
|
// Print the results
|
||||||
|
println!("Embedding done in {:?}", start_gen.elapsed());
|
||||||
|
println!("Scores: {:?}", scores.to_vec2::<f32>()?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -90,6 +90,9 @@ struct Args {
|
|||||||
/// The model size to use.
|
/// The model size to use.
|
||||||
#[arg(long, default_value = "phi-3b")]
|
#[arg(long, default_value = "phi-3b")]
|
||||||
which: Which,
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
use_flash_attn: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -213,7 +216,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
);
|
);
|
||||||
match args.which {
|
match args.which {
|
||||||
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
|
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
|
||||||
|
1,
|
||||||
|
args.use_flash_attn,
|
||||||
|
model,
|
||||||
|
&mut file,
|
||||||
|
&device,
|
||||||
|
)?),
|
||||||
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
|
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -7,7 +7,7 @@ extern crate accelerate_src;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
|
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
||||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
|
@ -39,7 +39,7 @@ struct Args {
|
|||||||
|
|
||||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||||
/// mask, positive makes the mask more selective.
|
/// mask, positive makes the mask more selective.
|
||||||
#[arg(long, default_value_t = 0.)]
|
#[arg(long, allow_hyphen_values = true, default_value_t = 0.)]
|
||||||
threshold: f32,
|
threshold: f32,
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
|||||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||||
|
if (smem_size >= 48 * 1024) {
|
||||||
|
cudaFuncSetAttribute(
|
||||||
|
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
|
||||||
|
}
|
||||||
// int ctas_per_sm;
|
// int ctas_per_sm;
|
||||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||||
|
@ -139,7 +139,9 @@ impl FlashAttn {
|
|||||||
|
|
||||||
let elem_count = out_shape.elem_count();
|
let elem_count = out_shape.elem_count();
|
||||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||||
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
|
let softmax_lse = dev
|
||||||
|
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
|
||||||
|
.w()?;
|
||||||
|
|
||||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||||
|
|
||||||
|
101
candle-nn/src/kv_cache.rs
Normal file
101
candle-nn/src/kv_cache.rs
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
use candle::{DType, Device, Result, Shape, Tensor};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Cache {
|
||||||
|
all_data: Tensor,
|
||||||
|
dim: usize,
|
||||||
|
current_seq_len: usize,
|
||||||
|
max_seq_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Cache {
|
||||||
|
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||||
|
dim: D,
|
||||||
|
shape: S,
|
||||||
|
dtype: DType,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let shape = shape.into();
|
||||||
|
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||||
|
let max_seq_len = shape.dims()[dim];
|
||||||
|
let all_data = Tensor::zeros(shape, dtype, dev)?;
|
||||||
|
Ok(Self {
|
||||||
|
all_data,
|
||||||
|
dim,
|
||||||
|
current_seq_len: 0,
|
||||||
|
max_seq_len,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dim(&self) -> usize {
|
||||||
|
self.dim
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn current_seq_len(&self) -> usize {
|
||||||
|
self.current_seq_len
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max_seq_len(&self) -> usize {
|
||||||
|
self.max_seq_len
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn all_data(&self) -> &Tensor {
|
||||||
|
&self.all_data
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn current_data(&self) -> Result<Tensor> {
|
||||||
|
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||||
|
let seq_len = src.dim(self.dim)?;
|
||||||
|
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||||
|
candle::bail!(
|
||||||
|
"kv-cache: above max-seq-len {}+{seq_len}>{}",
|
||||||
|
self.current_seq_len,
|
||||||
|
self.max_seq_len
|
||||||
|
)
|
||||||
|
}
|
||||||
|
self.all_data
|
||||||
|
.slice_set(src, self.dim, self.current_seq_len)?;
|
||||||
|
self.current_seq_len += seq_len;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct KvCache {
|
||||||
|
k: Cache,
|
||||||
|
v: Cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl KvCache {
|
||||||
|
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||||
|
dim: D,
|
||||||
|
shape: S,
|
||||||
|
dtype: DType,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let shape = shape.into();
|
||||||
|
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||||
|
let k = Cache::new(dim, &shape, dtype, dev)?;
|
||||||
|
let v = Cache::new(dim, &shape, dtype, dev)?;
|
||||||
|
Ok(Self { k, v })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn k(&self) -> Result<Tensor> {
|
||||||
|
self.k.current_data()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn v(&self) -> Result<Tensor> {
|
||||||
|
self.v.current_data()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
|
self.k.append(k)?;
|
||||||
|
self.v.append(v)?;
|
||||||
|
let k = self.k.current_data()?;
|
||||||
|
let v = self.v.current_data()?;
|
||||||
|
Ok((k, v))
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@ pub mod encoding;
|
|||||||
pub mod func;
|
pub mod func;
|
||||||
pub mod group_norm;
|
pub mod group_norm;
|
||||||
pub mod init;
|
pub mod init;
|
||||||
|
pub mod kv_cache;
|
||||||
pub mod layer_norm;
|
pub mod layer_norm;
|
||||||
pub mod linear;
|
pub mod linear;
|
||||||
pub mod loss;
|
pub mod loss;
|
||||||
|
@ -422,6 +422,32 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
s: Shape,
|
||||||
|
name: &str,
|
||||||
|
_: crate::Init,
|
||||||
|
dtype: DType,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
|
||||||
|
if tensor.shape() != &s {
|
||||||
|
Err(candle::Error::UnexpectedShape {
|
||||||
|
msg: format!("shape mismatch for {name}"),
|
||||||
|
expected: s,
|
||||||
|
got: tensor.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.get(name).is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> VarBuilder<'a> {
|
impl<'a> VarBuilder<'a> {
|
||||||
/// Initializes a `VarBuilder` using a custom backend.
|
/// Initializes a `VarBuilder` using a custom backend.
|
||||||
///
|
///
|
||||||
@ -481,12 +507,18 @@ impl<'a> VarBuilder<'a> {
|
|||||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Initializes a `VarBuilder` from a binary builder in the safetensor format.
|
/// Initializes a `VarBuilder` from a binary buffer in the safetensor format.
|
||||||
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
|
||||||
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` from a binary slice in the safetensor format.
|
||||||
|
pub fn from_slice_safetensors(data: &'a [u8], dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
|
let tensors = candle::safetensors::SliceSafetensors::new(data)?;
|
||||||
|
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
let npz = candle::npy::NpzTensors::new(p)?;
|
let npz = candle::npy::NpzTensors::new(p)?;
|
||||||
|
@ -971,7 +971,7 @@ pub fn simple_eval(
|
|||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
"RandomUniform" => {
|
random_type @ ("RandomUniform" | "RandomNormal") => {
|
||||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||||
// type by
|
// type by
|
||||||
// default
|
// default
|
||||||
@ -979,36 +979,42 @@ pub fn simple_eval(
|
|||||||
Ok(dt) => match dtype(dt) {
|
Ok(dt) => match dtype(dt) {
|
||||||
Some(DType::U8 | DType::U32 | DType::I64) => {
|
Some(DType::U8 | DType::U32 | DType::I64) => {
|
||||||
bail!(
|
bail!(
|
||||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
|
"unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
|
||||||
node.name
|
node.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
Some(dt) => dt,
|
Some(dt) => dt,
|
||||||
None => {
|
None => {
|
||||||
bail!(
|
bail!(
|
||||||
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
|
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||||
node.name
|
node.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
bail!(
|
bail!(
|
||||||
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
|
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||||
node.name
|
node.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
|
||||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
|
||||||
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
||||||
if seed.is_some() {
|
if seed.is_some() {
|
||||||
bail!("seed for RandomUniform is currently not supported")
|
bail!("seed for {random_type} is currently not supported")
|
||||||
};
|
};
|
||||||
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| *x as usize)
|
.map(|x| *x as usize)
|
||||||
.collect();
|
.collect();
|
||||||
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
|
let output = if random_type == "RandomUniform" {
|
||||||
|
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||||
|
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||||
|
Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||||
|
} else {
|
||||||
|
let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
|
||||||
|
let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
|
||||||
|
Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||||
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
|
@ -2020,6 +2020,150 @@ fn test_random_uniform() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "RandomNormal"
|
||||||
|
#[test]
|
||||||
|
fn test_random_normal() -> Result<()> {
|
||||||
|
test(vec![3, 2, 1, 4], None, None)?;
|
||||||
|
test(vec![2, 2, 2, 2], Some(-10.0), None)?;
|
||||||
|
test(vec![2, 2, 2, 2], None, Some(10.0))?;
|
||||||
|
test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
|
||||||
|
|
||||||
|
fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
|
||||||
|
let att_mean = AttributeProto {
|
||||||
|
name: "mean".to_string(),
|
||||||
|
ref_attr_name: "mean".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "mean".to_string(),
|
||||||
|
r#type: 1, // FLOAT
|
||||||
|
f: mean.unwrap_or(0.0),
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_scale = AttributeProto {
|
||||||
|
name: "scale".to_string(),
|
||||||
|
ref_attr_name: "scale".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "scale".to_string(),
|
||||||
|
r#type: 1, // FLOAT
|
||||||
|
f: scale.unwrap_or(1.0),
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_shape = AttributeProto {
|
||||||
|
name: "shape".to_string(),
|
||||||
|
ref_attr_name: "shape".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "shape".to_string(),
|
||||||
|
r#type: 7, // INTS
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: shape,
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_dtype = AttributeProto {
|
||||||
|
name: "dtype".to_string(),
|
||||||
|
ref_attr_name: "dtype".to_string(),
|
||||||
|
i: 11, // DOUBLE
|
||||||
|
doc_string: "dtype".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let attrs = {
|
||||||
|
let mut mut_attrs = vec![att_shape, att_dtype];
|
||||||
|
if mean.is_some() {
|
||||||
|
mut_attrs.push(att_mean);
|
||||||
|
}
|
||||||
|
if scale.is_some() {
|
||||||
|
mut_attrs.push(att_scale);
|
||||||
|
}
|
||||||
|
mut_attrs
|
||||||
|
};
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "RandomNormal".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: attrs,
|
||||||
|
input: vec![],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let data = z.flatten_all()?.to_vec1::<f64>()?;
|
||||||
|
|
||||||
|
// test if values are unique
|
||||||
|
for (i, a) in data.iter().enumerate() {
|
||||||
|
for (j, b) in data.iter().enumerate() {
|
||||||
|
if i == j {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
assert_ne!(a, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// "Range"
|
// "Range"
|
||||||
#[test]
|
#[test]
|
||||||
fn test_range() -> Result<()> {
|
fn test_range() -> Result<()> {
|
||||||
|
@ -73,13 +73,6 @@ struct RotaryEmbedding {
|
|||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let last_dim = xs.dim(D::Minus1)?;
|
|
||||||
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
|
|
||||||
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
|
|
||||||
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
let dim = cfg.head_dim;
|
let dim = cfg.head_dim;
|
||||||
@ -94,7 +87,6 @@ impl RotaryEmbedding {
|
|||||||
.to_dtype(dtype)?
|
.to_dtype(dtype)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?,
|
sin: freqs.sin()?,
|
||||||
cos: freqs.cos()?,
|
cos: freqs.cos()?,
|
||||||
@ -110,10 +102,8 @@ impl RotaryEmbedding {
|
|||||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
|
|
||||||
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
|
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -163,10 +153,16 @@ struct Attention {
|
|||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
rotary_emb: Arc<RotaryEmbedding>,
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
kv_cache: Option<(Tensor, Tensor)>,
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
use_flash_attn: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Attention {
|
impl Attention {
|
||||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
fn new(
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
use_flash_attn: bool,
|
||||||
|
cfg: &Config,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
let hidden_sz = cfg.hidden_size;
|
let hidden_sz = cfg.hidden_size;
|
||||||
let num_heads = cfg.num_attention_heads;
|
let num_heads = cfg.num_attention_heads;
|
||||||
let num_kv_heads = cfg.num_key_value_heads;
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
@ -188,6 +184,7 @@ impl Attention {
|
|||||||
head_dim,
|
head_dim,
|
||||||
rotary_emb,
|
rotary_emb,
|
||||||
kv_cache: None,
|
kv_cache: None,
|
||||||
|
use_flash_attn,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,7 +228,14 @@ impl Attention {
|
|||||||
let value_states =
|
let value_states =
|
||||||
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
let attn_output = {
|
let attn_output = if self.use_flash_attn {
|
||||||
|
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||||
|
let q = query_states.transpose(1, 2)?;
|
||||||
|
let k = key_states.transpose(1, 2)?;
|
||||||
|
let v = value_states.transpose(1, 2)?;
|
||||||
|
let scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||||
|
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
|
||||||
|
} else {
|
||||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
@ -253,6 +257,22 @@ impl Attention {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "flash-attn")]
|
||||||
|
fn flash_attn(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
softmax_scale: f32,
|
||||||
|
causal: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "flash-attn"))]
|
||||||
|
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||||
|
unimplemented!("compile with '--features flash-attn'")
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct DecoderLayer {
|
struct DecoderLayer {
|
||||||
self_attn: Attention,
|
self_attn: Attention,
|
||||||
@ -262,8 +282,13 @@ struct DecoderLayer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DecoderLayer {
|
impl DecoderLayer {
|
||||||
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
fn new(
|
||||||
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
use_flash_attn: bool,
|
||||||
|
cfg: &Config,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
|
||||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
let input_layernorm =
|
let input_layernorm =
|
||||||
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||||
@ -312,7 +337,7 @@ pub struct Model {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
let vb_m = vb.pp("model");
|
let vb_m = vb.pp("model");
|
||||||
let embed_tokens =
|
let embed_tokens =
|
||||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
@ -320,7 +345,8 @@ impl Model {
|
|||||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
let vb_l = vb_m.pp("layers");
|
let vb_l = vb_m.pp("layers");
|
||||||
for layer_idx in 0..cfg.num_hidden_layers {
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
let layer =
|
||||||
|
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
|
||||||
layers.push(layer)
|
layers.push(layer)
|
||||||
}
|
}
|
||||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
|
@ -3,9 +3,7 @@ use std::collections::HashMap;
|
|||||||
use candle::quantized::gguf_file;
|
use candle::quantized::gguf_file;
|
||||||
use candle::quantized::QTensor;
|
use candle::quantized::QTensor;
|
||||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::{Embedding, RmsNorm};
|
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
|
||||||
|
|
||||||
pub const MAX_SEQ_LEN: usize = 4096;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct QLinear {
|
struct QLinear {
|
||||||
@ -70,7 +68,8 @@ struct LayerWeights {
|
|||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
neg_inf: Tensor,
|
neg_inf: Tensor,
|
||||||
kv_cache: Option<(Tensor, Tensor)>,
|
kv_cache: KvCache,
|
||||||
|
use_flash_attn: bool,
|
||||||
span_attn: tracing::Span,
|
span_attn: tracing::Span,
|
||||||
span_rot: tracing::Span,
|
span_rot: tracing::Span,
|
||||||
}
|
}
|
||||||
@ -122,23 +121,21 @@ impl LayerWeights {
|
|||||||
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
||||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||||
|
|
||||||
let (k, v) = match &self.kv_cache {
|
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
|
||||||
None => (k.contiguous()?, v.contiguous()?),
|
|
||||||
Some((k_cache, v_cache)) => {
|
|
||||||
if index_pos == 0 {
|
|
||||||
(k.contiguous()?, v.contiguous()?)
|
|
||||||
} else {
|
|
||||||
let k = Tensor::cat(&[k_cache, &k], 2)?;
|
|
||||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
|
||||||
(k.contiguous()?, v.contiguous()?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
self.kv_cache = Some((k.clone(), v.clone()));
|
|
||||||
|
|
||||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||||
|
|
||||||
|
let y = if self.use_flash_attn {
|
||||||
|
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||||
|
let q = q.to_dtype(DType::BF16)?.transpose(1, 2)?;
|
||||||
|
let k = k.to_dtype(DType::BF16)?.transpose(1, 2)?;
|
||||||
|
let v = v.to_dtype(DType::BF16)?.transpose(1, 2)?;
|
||||||
|
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||||
|
flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
} else {
|
||||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||||
let att = match mask {
|
let att = match mask {
|
||||||
None => att,
|
None => att,
|
||||||
@ -149,13 +146,30 @@ impl LayerWeights {
|
|||||||
};
|
};
|
||||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
let y = att.matmul(&v.contiguous()?)?;
|
att.matmul(&v.contiguous()?)?
|
||||||
|
};
|
||||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||||
let y = self.attn_output.forward(&y)?;
|
let y = self.attn_output.forward(&y)?;
|
||||||
Ok(y)
|
Ok(y)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "flash-attn")]
|
||||||
|
fn flash_attn(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
softmax_scale: f32,
|
||||||
|
causal: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "flash-attn"))]
|
||||||
|
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||||
|
unimplemented!("compile with '--features flash-attn'")
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ModelWeights {
|
pub struct ModelWeights {
|
||||||
tok_embeddings: Embedding,
|
tok_embeddings: Embedding,
|
||||||
@ -169,6 +183,7 @@ pub struct ModelWeights {
|
|||||||
|
|
||||||
fn precomput_freqs_cis(
|
fn precomput_freqs_cis(
|
||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
|
max_seq_len: usize,
|
||||||
freq_base: f32,
|
freq_base: f32,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
@ -177,9 +192,9 @@ fn precomput_freqs_cis(
|
|||||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.reshape((MAX_SEQ_LEN, 1))?
|
.reshape((max_seq_len, 1))?
|
||||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
let cos = idx_theta.cos()?;
|
let cos = idx_theta.cos()?;
|
||||||
let sin = idx_theta.sin()?;
|
let sin = idx_theta.sin()?;
|
||||||
@ -188,6 +203,8 @@ fn precomput_freqs_cis(
|
|||||||
|
|
||||||
impl ModelWeights {
|
impl ModelWeights {
|
||||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||||
|
batch_size: usize,
|
||||||
|
use_flash_attn: bool,
|
||||||
ct: gguf_file::Content,
|
ct: gguf_file::Content,
|
||||||
reader: &mut R,
|
reader: &mut R,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -202,16 +219,19 @@ impl ModelWeights {
|
|||||||
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
|
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
|
||||||
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
|
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
|
||||||
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
|
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
|
||||||
|
let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
|
||||||
|
let head_dim = embedding_length / head_count;
|
||||||
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
|
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
|
||||||
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
|
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
|
||||||
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||||
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
|
let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
|
||||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||||
|
|
||||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||||
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
|
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
|
||||||
let output = QLinear::new(&ct, reader, "output", device)?;
|
let output = QLinear::new(&ct, reader, "output", device)?;
|
||||||
|
|
||||||
let mut layers = Vec::with_capacity(block_count);
|
let mut layers = Vec::with_capacity(block_count);
|
||||||
for layer_idx in 0..block_count {
|
for layer_idx in 0..block_count {
|
||||||
let prefix = format!("blk.{layer_idx}");
|
let prefix = format!("blk.{layer_idx}");
|
||||||
@ -232,6 +252,12 @@ impl ModelWeights {
|
|||||||
)?;
|
)?;
|
||||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||||
|
let kv_cache = KvCache::new(
|
||||||
|
2,
|
||||||
|
(batch_size, head_count_kv, max_seq_len, head_dim),
|
||||||
|
DType::F32,
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
layers.push(LayerWeights {
|
layers.push(LayerWeights {
|
||||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||||
@ -240,11 +266,12 @@ impl ModelWeights {
|
|||||||
mlp,
|
mlp,
|
||||||
n_head: head_count,
|
n_head: head_count,
|
||||||
n_kv_head: head_count_kv,
|
n_kv_head: head_count_kv,
|
||||||
head_dim: embedding_length / head_count,
|
head_dim,
|
||||||
cos: cos.clone(),
|
cos: cos.clone(),
|
||||||
sin: sin.clone(),
|
sin: sin.clone(),
|
||||||
neg_inf: neg_inf.clone(),
|
neg_inf: neg_inf.clone(),
|
||||||
kv_cache: None,
|
kv_cache,
|
||||||
|
use_flash_attn,
|
||||||
span_attn,
|
span_attn,
|
||||||
span_rot,
|
span_rot,
|
||||||
})
|
})
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
|
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
|
||||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::{Activation, VarBuilder};
|
use candle_nn::{Activation, VarBuilder};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@ -250,7 +250,6 @@ pub struct Model {
|
|||||||
embed_tokens: candle_nn::Embedding,
|
embed_tokens: candle_nn::Embedding,
|
||||||
layers: Vec<DecoderLayer>,
|
layers: Vec<DecoderLayer>,
|
||||||
norm: RmsNorm,
|
norm: RmsNorm,
|
||||||
lm_head: Linear,
|
|
||||||
sliding_window: usize,
|
sliding_window: usize,
|
||||||
device: Device,
|
device: Device,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
@ -269,19 +268,17 @@ impl Model {
|
|||||||
layers.push(layer)
|
layers.push(layer)
|
||||||
}
|
}
|
||||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
embed_tokens,
|
embed_tokens,
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
lm_head,
|
|
||||||
sliding_window: cfg.sliding_window,
|
sliding_window: cfg.sliding_window,
|
||||||
device: vb.device().clone(),
|
device: vb.device().clone(),
|
||||||
dtype: vb.dtype(),
|
dtype: vb.dtype(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_decoder_attention_mask(
|
fn prepare_causal_attention_mask(
|
||||||
&self,
|
&self,
|
||||||
b_size: usize,
|
b_size: usize,
|
||||||
tgt_len: usize,
|
tgt_len: usize,
|
||||||
@ -301,7 +298,7 @@ impl Model {
|
|||||||
.collect();
|
.collect();
|
||||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
let mask = if seqlen_offset > 0 {
|
let mask = if seqlen_offset > 0 {
|
||||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
|
||||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
} else {
|
} else {
|
||||||
mask
|
mask
|
||||||
@ -310,21 +307,42 @@ impl Model {
|
|||||||
.to_dtype(self.dtype)
|
.to_dtype(self.dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b_sz, sql_len) = attn_mask.dims2()?;
|
||||||
|
let mut mask: Vec<Tensor> = vec![];
|
||||||
|
for b in 0..b_sz {
|
||||||
|
mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
|
||||||
|
}
|
||||||
|
let mask = Tensor::cat(&mask, 0)?;
|
||||||
|
let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
|
||||||
|
let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
|
||||||
|
.broadcast_as(mask.shape())?
|
||||||
|
.to_dtype(self.dtype)?;
|
||||||
|
mask.where_cond(&on_true, &on_false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(
|
||||||
|
&mut self,
|
||||||
|
input_ids: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
let (b_size, seq_len) = input_ids.dims2()?;
|
let (b_size, seq_len) = input_ids.dims2()?;
|
||||||
let attention_mask = if seq_len <= 1 {
|
let attention_mask: Option<Tensor> = match attn_mask {
|
||||||
|
Some(mask) => Some(self.prepare_attention_mask(mask)?),
|
||||||
|
None => {
|
||||||
|
if seq_len <= 1 {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
|
||||||
Some(mask)
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let mut xs = self.embed_tokens.forward(input_ids)?;
|
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||||
}
|
}
|
||||||
xs.narrow(1, seq_len - 1, 1)?
|
xs.apply(&self.norm)
|
||||||
.apply(&self.norm)?
|
|
||||||
.apply(&self.lm_head)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
@ -333,3 +351,32 @@ impl Model {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelForCausalLM {
|
||||||
|
base_model: Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelForCausalLM {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
|
let base_model = Model::new(cfg, vb)?;
|
||||||
|
Ok(Self {
|
||||||
|
base_model,
|
||||||
|
lm_head,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (_b_size, seq_len) = input_ids.dims2()?;
|
||||||
|
self.base_model
|
||||||
|
.forward(input_ids, seqlen_offset, None)?
|
||||||
|
.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base_model.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -21,7 +21,7 @@ log = { workspace = true }
|
|||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
wav = { workspace = true }
|
hound = { workspace = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
|
|
||||||
# Wasm specific crates.
|
# Wasm specific crates.
|
||||||
|
@ -345,16 +345,19 @@ impl Decoder {
|
|||||||
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
pub fn convert_and_run(&mut self, wav_input: &[u8]) -> anyhow::Result<Vec<Segment>> {
|
||||||
let device = Device::Cpu;
|
let device = Device::Cpu;
|
||||||
let mut wav_input = std::io::Cursor::new(wav_input);
|
let mut wav_input = std::io::Cursor::new(wav_input);
|
||||||
let (header, data) = wav::read(&mut wav_input)?;
|
let wav_reader = hound::WavReader::new(&mut wav_input)?;
|
||||||
console_log!("loaded wav data: {header:?}");
|
let spec = wav_reader.spec();
|
||||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
console_log!("loaded wav data: {spec:?}");
|
||||||
|
if spec.sample_rate != m::SAMPLE_RATE as u32 {
|
||||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
|
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE);
|
||||||
}
|
}
|
||||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
let mut data = wav_reader.into_samples::<i16>().collect::<Vec<_>>();
|
||||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
data.truncate(data.len() / spec.channels as usize);
|
||||||
.iter()
|
let mut pcm_data = Vec::with_capacity(data.len());
|
||||||
.map(|v| *v as f32 / 32768.)
|
for d in data.into_iter() {
|
||||||
.collect();
|
let d = d?;
|
||||||
|
pcm_data.push(d as f32 / 32768.)
|
||||||
|
}
|
||||||
console_log!("pcm data loaded {}", pcm_data.len());
|
console_log!("pcm data loaded {}", pcm_data.len());
|
||||||
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
|
let mel = crate::audio::pcm_to_mel(self.model.config(), &pcm_data, &self.mel_filters)?;
|
||||||
let mel_len = mel.len();
|
let mel_len = mel.len();
|
||||||
|
Reference in New Issue
Block a user