mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
10 Commits
mkl_link_f
...
0.9.0-alph
Author | SHA1 | Date | |
---|---|---|---|
cf9d7bf24c | |||
9d31361c4f | |||
648596c073 | |||
d9904a3baf | |||
d6db305829 | |||
b4daa03e59 | |||
9541467d6b | |||
6429609090 | |||
ba473290da | |||
59c26195db |
26
Cargo.toml
26
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
|
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
|
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.8.4" }
|
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
|
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
|
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
fancy-regex = "0.13.0"
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.4.1"
|
hf-hub = "0.4.1"
|
||||||
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
ug = "0.1.0"
|
ug = "0.2.0"
|
||||||
ug-cuda = "0.1.0"
|
ug-cuda = "0.2.0"
|
||||||
ug-metal = "0.1.0"
|
ug-metal = "0.2.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "1.1.1", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
@ -21,7 +21,9 @@ impl BenchDevice for Device {
|
|||||||
Device::Cpu => Ok(()),
|
Device::Cpu => Ok(()),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
return Ok(device.synchronize()?);
|
return Ok(device
|
||||||
|
.synchronize()
|
||||||
|
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||||
}
|
}
|
||||||
|
@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
|
|||||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||||
return Ok(cudnn.clone());
|
return Ok(cudnn.clone());
|
||||||
}
|
}
|
||||||
let c = Cudnn::new(dev.cuda_device());
|
let c = Cudnn::new(dev.cuda_stream());
|
||||||
if let Ok(c) = &c {
|
if let Ok(c) = &c {
|
||||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||||
}
|
}
|
||||||
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
|
|||||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||||
};
|
};
|
||||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||||
unsafe {
|
unsafe {
|
||||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||||
alg,
|
alg,
|
||||||
|
@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
|
|||||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||||
pub use candle_kernels as kernels;
|
pub use candle_kernels as kernels;
|
||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||||
@ -24,10 +25,17 @@ impl DeviceId {
|
|||||||
struct CudaRng(cudarc::curand::CudaRng);
|
struct CudaRng(cudarc::curand::CudaRng);
|
||||||
unsafe impl Send for CudaRng {}
|
unsafe impl Send for CudaRng {}
|
||||||
|
|
||||||
|
pub struct ModuleStore {
|
||||||
|
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CudaDevice {
|
pub struct CudaDevice {
|
||||||
id: DeviceId,
|
id: DeviceId,
|
||||||
device: Arc<cudarc::driver::CudaDevice>,
|
context: Arc<cudarc::driver::CudaContext>,
|
||||||
|
modules: Arc<std::sync::RwLock<ModuleStore>>,
|
||||||
|
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
|
||||||
|
stream: Arc<cudarc::driver::CudaStream>,
|
||||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
curand: Arc<Mutex<CudaRng>>,
|
curand: Arc<Mutex<CudaRng>>,
|
||||||
}
|
}
|
||||||
@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for CudaDevice {
|
impl std::ops::Deref for CudaDevice {
|
||||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
type Target = Arc<cudarc::driver::CudaStream>;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.device
|
&self.stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CudaFunc {
|
||||||
|
func: CudaFunction,
|
||||||
|
stream: Arc<cudarc::driver::CudaStream>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for CudaFunc {
|
||||||
|
type Target = CudaFunction;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.func
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaFunc {
|
||||||
|
pub fn into_cuda_function(self) -> CudaFunction {
|
||||||
|
self.func
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! builder_arg {
|
||||||
|
($b:ident, $($arg:expr),*) => {
|
||||||
|
$(
|
||||||
|
let __arg = $arg;
|
||||||
|
$b.arg(&__arg);
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaFunc {
|
||||||
|
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
|
||||||
|
self.stream.launch_builder(&self.func)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CudaDevice {
|
impl CudaDevice {
|
||||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
|
||||||
self.device.clone()
|
self.stream.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
@ -56,7 +99,7 @@ impl CudaDevice {
|
|||||||
&self,
|
&self,
|
||||||
func_name: &'static str,
|
func_name: &'static str,
|
||||||
kernel: ug::lang::ssa::Kernel,
|
kernel: ug::lang::ssa::Kernel,
|
||||||
) -> Result<CudaFunction> {
|
) -> Result<CudaFunc> {
|
||||||
let mut buf = vec![];
|
let mut buf = vec![];
|
||||||
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||||
let cuda_code = String::from_utf8(buf)?;
|
let cuda_code = String::from_utf8(buf)?;
|
||||||
@ -65,12 +108,12 @@ impl CudaDevice {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
||||||
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
|
let module = self.context.load_module(ptx).w()?;
|
||||||
let func = match self.device.get_func("ug", func_name) {
|
let func = module.load_function(func_name).w()?;
|
||||||
Some(func) => func,
|
Ok(CudaFunc {
|
||||||
None => crate::bail!("unknown function ug::{func_name}"),
|
func,
|
||||||
};
|
stream: self.stream.clone(),
|
||||||
Ok(func)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
pub fn id(&self) -> DeviceId {
|
||||||
@ -84,57 +127,84 @@ impl CudaDevice {
|
|||||||
DType::U8 => {
|
DType::U8 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||||
let params = (&data, v as u8, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as u8;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
DType::U32 => {
|
DType::U32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||||
let params = (&data, v as u32, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as u32;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
DType::I64 => {
|
DType::I64 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||||
let params = (&data, v as i64, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as i64;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
DType::BF16 => {
|
DType::BF16 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||||
let params = (&data, bf16::from_f64(v), elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = bf16::from_f64(v);
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
DType::F16 => {
|
DType::F16 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||||
let params = (&data, f16::from_f64(v), elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = f16::from_f64(v);
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||||
let params = (&data, v as f32, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as f32;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
||||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||||
let params = (&data, v, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -144,38 +214,69 @@ impl CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
pub fn get_or_load_custom_func(
|
||||||
if !self.has_func(module_name, module_name) {
|
&self,
|
||||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
fn_name: &str,
|
||||||
// done once per kernel name.
|
module_name: &str,
|
||||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
ptx: &str,
|
||||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
) -> Result<CudaFunc> {
|
||||||
.map_err(|cuda| CudaError::Load {
|
let ms = self.custom_modules.read().unwrap();
|
||||||
cuda,
|
if let Some(mdl) = ms.get(module_name).as_ref() {
|
||||||
module_name: module_name.to_string(),
|
let func = mdl.load_function(fn_name).w()?;
|
||||||
})
|
return Ok(CudaFunc {
|
||||||
.w()?;
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
self.get_func(module_name, module_name)
|
drop(ms);
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
let mut ms = self.custom_modules.write().unwrap();
|
||||||
// able to only build the error value if needed.
|
let cuda_module = self.context.load_module(ptx.into()).w()?;
|
||||||
.ok_or(CudaError::MissingKernel {
|
ms.insert(module_name.to_string(), cuda_module.clone());
|
||||||
module_name: module_name.to_string(),
|
let func = cuda_module.load_function(fn_name).w()?;
|
||||||
})
|
Ok(CudaFunc {
|
||||||
.w()
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
|
||||||
|
let ms = self.modules.read().unwrap();
|
||||||
|
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
|
||||||
|
let func = mdl.load_function(fn_name).w()?;
|
||||||
|
return Ok(CudaFunc {
|
||||||
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
drop(ms);
|
||||||
|
let mut ms = self.modules.write().unwrap();
|
||||||
|
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
|
||||||
|
ms.mdls[mdl.index()] = Some(cuda_module.clone());
|
||||||
|
let func = cuda_module.load_function(fn_name).w()?;
|
||||||
|
Ok(CudaFunc {
|
||||||
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CudaDevice {
|
impl CudaDevice {
|
||||||
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
|
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
let stream = context.new_stream().w()?;
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||||
|
let module_store = ModuleStore {
|
||||||
|
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: DeviceId::new(),
|
id: DeviceId::new(),
|
||||||
device,
|
context,
|
||||||
|
stream,
|
||||||
blas: Arc::new(blas),
|
blas: Arc::new(blas),
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||||
|
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
|
|||||||
type Storage = CudaStorage;
|
type Storage = CudaStorage;
|
||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
let stream = context.default_stream();
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||||
|
let module_store = ModuleStore {
|
||||||
|
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: DeviceId::new(),
|
id: DeviceId::new(),
|
||||||
device,
|
context,
|
||||||
|
stream,
|
||||||
blas: Arc::new(blas),
|
blas: Arc::new(blas),
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||||
|
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
|
|||||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||||
// state will be identical and the same random numbers will be generated.
|
// state will be identical and the same random numbers will be generated.
|
||||||
let mut curand = self.curand.lock().unwrap();
|
let mut curand = self.curand.lock().unwrap();
|
||||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
crate::DeviceLocation::Cuda {
|
crate::DeviceLocation::Cuda {
|
||||||
gpu_id: self.device.ordinal(),
|
gpu_id: self.context.ordinal(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
let slice = match T::cpu_storage_ref(s) {
|
let slice = match T::cpu_storage_ref(s) {
|
||||||
CpuStorageRef::U8(storage) => {
|
CpuStorageRef::U8(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::U32(storage) => {
|
CpuStorageRef::U32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::I64(storage) => {
|
CpuStorageRef::I64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::BF16(storage) => {
|
CpuStorageRef::BF16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::F16(storage) => {
|
CpuStorageRef::F16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::F32(storage) => {
|
CpuStorageRef::F32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::F64(storage) => {
|
CpuStorageRef::F64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
let slice = match storage {
|
let slice = match storage {
|
||||||
CpuStorage::U8(storage) => {
|
CpuStorage::U8(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
CpuStorage::U32(storage) => {
|
CpuStorage::U32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::I64(storage) => {
|
CpuStorage::I64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
CpuStorage::BF16(storage) => {
|
CpuStorage::BF16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F16(storage) => {
|
CpuStorage::F16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F64(storage) => {
|
CpuStorage::F64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage).w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||||
let slice = match storage {
|
let slice = match storage {
|
||||||
CpuStorage::U8(storage) => {
|
CpuStorage::U8(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
CpuStorage::U32(storage) => {
|
CpuStorage::U32(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::I64(storage) => {
|
CpuStorage::I64(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
CpuStorage::BF16(storage) => {
|
CpuStorage::BF16(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F16(storage) => {
|
CpuStorage::F16(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F64(storage) => {
|
CpuStorage::F64(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage).w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn synchronize(&self) -> Result<()> {
|
fn synchronize(&self) -> Result<()> {
|
||||||
self.device.synchronize().map_err(crate::Error::wrap)?;
|
self.stream.synchronize().map_err(crate::Error::wrap)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -396,7 +396,10 @@ impl UgIOp1 {
|
|||||||
{
|
{
|
||||||
let device = device.as_cuda_device()?;
|
let device = device.as_cuda_device()?;
|
||||||
let func = device.compile(name, kernel)?;
|
let func = device.compile(name, kernel)?;
|
||||||
Ok(Self { name, func })
|
Ok(Self {
|
||||||
|
name,
|
||||||
|
func: func.into_cuda_function(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
{
|
{
|
||||||
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
|
|||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
||||||
use crate::cuda_backend::WrapErr;
|
use crate::cuda_backend::WrapErr;
|
||||||
use cudarc::driver::LaunchAsync;
|
use cudarc::driver::PushKernelArg;
|
||||||
|
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
|
let stream = sto.device.cuda_stream();
|
||||||
// TODO: support more dtypes.
|
// TODO: support more dtypes.
|
||||||
let sto = sto.as_cuda_slice::<f32>()?;
|
let sto = sto.as_cuda_slice::<f32>()?;
|
||||||
let sto = match layout.contiguous_offsets() {
|
let sto = match layout.contiguous_offsets() {
|
||||||
None => crate::bail!("input has to be contiguous"),
|
None => crate::bail!("input has to be contiguous"),
|
||||||
Some((o1, o2)) => sto.slice(o1..o2),
|
Some((o1, o2)) => sto.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let params = (&sto,);
|
|
||||||
let (g, b) = if elem_count % 32 == 0 {
|
let (g, b) = if elem_count % 32 == 0 {
|
||||||
(elem_count / 32, 32)
|
(elem_count / 32, 32)
|
||||||
} else {
|
} else {
|
||||||
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
|
|||||||
block_dim: (b as u32, 1, 1),
|
block_dim: (b as u32, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
unsafe { self.func.clone().launch(cfg, params) }.w()?;
|
let mut builder = stream.launch_builder(&self.func);
|
||||||
|
builder.arg(&sto);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -816,7 +816,7 @@ impl PthTensors {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `path` - Path to the pth file.
|
/// * `path` - Path to the pth file.
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||||
path: P,
|
path: P,
|
||||||
key: Option<&str>,
|
key: Option<&str>,
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
use super::{GgmlDType, QStorage};
|
||||||
use crate::quantized::k_quants::GgmlType;
|
use crate::quantized::k_quants::GgmlType;
|
||||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||||
use crate::{CudaDevice, CudaStorage, Result};
|
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
|
||||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct PaddedCudaSlice {
|
struct PaddedCudaSlice {
|
||||||
@ -50,19 +50,20 @@ fn quantize_q8_1(
|
|||||||
ky: usize,
|
ky: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let kx = elem_count;
|
let kx = elem_count;
|
||||||
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
||||||
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
||||||
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (num_blocks as u32, ky as u32, 1),
|
grid_dim: (num_blocks as u32, ky as u32, 1),
|
||||||
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let params = (src, dst, kx as i32, kx_padded as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(src);
|
||||||
|
builder.arg(dst);
|
||||||
|
barg!(builder, kx as i32, kx_padded as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,8 +73,6 @@ fn dequantize_f32(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let nb = (elem_count + 255) / 256;
|
let nb = (elem_count + 255) / 256;
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||||
@ -99,7 +98,7 @@ fn dequantize_f32(
|
|||||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
||||||
// See e.g.
|
// See e.g.
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
@ -110,15 +109,20 @@ fn dequantize_f32(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if is_k {
|
if is_k {
|
||||||
let params = (&data.inner, &dst);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
} else {
|
} else {
|
||||||
let nb32 = match dtype {
|
let nb32 = match dtype {
|
||||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
_ => elem_count / 32,
|
_ => elem_count / 32,
|
||||||
};
|
};
|
||||||
let params = (&data.inner, &dst, nb32 as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, nb32 as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
}
|
}
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
@ -129,8 +133,6 @@ fn dequantize_f16(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let nb = (elem_count + 255) / 256;
|
let nb = (elem_count + 255) / 256;
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||||
@ -156,7 +158,7 @@ fn dequantize_f16(
|
|||||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||||
// See e.g.
|
// See e.g.
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
@ -167,15 +169,20 @@ fn dequantize_f16(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if is_k {
|
if is_k {
|
||||||
let params = (&data.inner, &dst);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
} else {
|
} else {
|
||||||
let nb32 = match dtype {
|
let nb32 = match dtype {
|
||||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
_ => elem_count / 32,
|
_ => elem_count / 32,
|
||||||
};
|
};
|
||||||
let params = (&data.inner, &dst, nb32 as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, nb32 as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
}
|
}
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
|
|||||||
nrows: usize,
|
nrows: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||||
if data_elems < ncols * nrows {
|
if data_elems < ncols * nrows {
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
@ -210,7 +215,7 @@ fn dequantize_mul_mat_vec(
|
|||||||
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
|
|||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(y);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, ncols as i32, nrows as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
b_size: usize,
|
b_size: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||||
if data_elems < ncols * nrows {
|
if data_elems < ncols * nrows {
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
@ -266,7 +273,7 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let kernel_name = format!("{kernel_name}{b_size}");
|
let kernel_name = format!("{kernel_name}{b_size}");
|
||||||
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||||
let (nblocks, nwarps) = match b_size {
|
let (nblocks, nwarps) = match b_size {
|
||||||
@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&data.inner,
|
builder.arg(&data.inner);
|
||||||
&y_q8_1,
|
builder.arg(&y_q8_1);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
|
barg!(
|
||||||
|
builder,
|
||||||
/* ncols_x */ ncols as i32,
|
/* ncols_x */ ncols as i32,
|
||||||
/* nrows_x */ nrows as i32,
|
/* nrows_x */ nrows as i32,
|
||||||
/* nrows_y */ ncols_padded as i32,
|
/* nrows_y */ ncols_padded as i32,
|
||||||
/* nrows_dst */ nrows as i32,
|
/* nrows_dst */ nrows as i32
|
||||||
);
|
);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
|
|||||||
y_cols: usize,
|
y_cols: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||||
if data_elems < x_rows * x_cols {
|
if data_elems < x_rows * x_cols {
|
||||||
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
||||||
@ -338,7 +345,7 @@ fn mul_mat_via_q8_1(
|
|||||||
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (
|
grid_dim: (
|
||||||
@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
|
|||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
/* vx */ &data.inner,
|
builder.arg(/* vx */ &data.inner);
|
||||||
/* vy */ &y_q8_1,
|
builder.arg(/* vy */ &y_q8_1);
|
||||||
/* dst */ &dst,
|
builder.arg(/* dst */ &dst);
|
||||||
|
barg!(
|
||||||
|
builder,
|
||||||
/* ncols_x */ x_cols as i32,
|
/* ncols_x */ x_cols as i32,
|
||||||
/* nrows_x */ x_rows as i32,
|
/* nrows_x */ x_rows as i32,
|
||||||
/* ncols_y */ y_cols as i32,
|
/* ncols_y */ y_cols as i32,
|
||||||
/* nrows_y */ k_padded as i32,
|
/* nrows_y */ k_padded as i32,
|
||||||
/* nrows_dst */ x_rows as i32,
|
/* nrows_dst */ x_rows as i32
|
||||||
);
|
);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,7 +425,7 @@ impl QCudaStorage {
|
|||||||
|
|
||||||
let buffer = self
|
let buffer = self
|
||||||
.device
|
.device
|
||||||
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
|
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
|
||||||
.w()?;
|
.w()?;
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
let block_len = elem_count / self.dtype.block_size();
|
let block_len = elem_count / self.dtype.block_size();
|
||||||
@ -449,7 +458,7 @@ impl QCudaStorage {
|
|||||||
// Run the quantization on cpu.
|
// Run the quantization on cpu.
|
||||||
let src = match &src.slice {
|
let src = match &src.slice {
|
||||||
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
||||||
self.device.dtoh_sync_copy(data).w()?
|
self.device.memcpy_dtov(data).w()?
|
||||||
}
|
}
|
||||||
_ => crate::bail!("only f32 can be quantized"),
|
_ => crate::bail!("only f32 can be quantized"),
|
||||||
};
|
};
|
||||||
@ -462,7 +471,7 @@ impl QCudaStorage {
|
|||||||
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
||||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
|
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
|
||||||
self.device
|
self.device
|
||||||
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
||||||
.w()?;
|
.w()?;
|
||||||
self.data = PaddedCudaSlice {
|
self.data = PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
@ -599,7 +608,7 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
||||||
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
|
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
|
||||||
device
|
device
|
||||||
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
|
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
|
||||||
.w()?;
|
.w()?;
|
||||||
Ok(QStorage::Cuda(QCudaStorage {
|
Ok(QStorage::Cuda(QCudaStorage {
|
||||||
data: PaddedCudaSlice {
|
data: PaddedCudaSlice {
|
||||||
@ -624,7 +633,7 @@ mod test {
|
|||||||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -634,7 +643,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let ncols = 256;
|
let ncols = 256;
|
||||||
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_vec_via_q8_1(
|
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||||
@ -647,7 +656,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
assert_eq!(vs.len(), 1);
|
assert_eq!(vs.len(), 1);
|
||||||
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||||
// Q8 means 1/256 precision.
|
// Q8 means 1/256 precision.
|
||||||
@ -662,7 +671,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
assert_eq!(vs.len(), 1);
|
assert_eq!(vs.len(), 1);
|
||||||
assert_eq!(vs[0], 5561851.0);
|
assert_eq!(vs[0], 5561851.0);
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -673,7 +682,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let ncols = 256;
|
let ncols = 256;
|
||||||
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_via_q8_1(
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
@ -687,7 +696,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
|
|
||||||
/*
|
/*
|
||||||
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||||
@ -714,7 +723,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
||||||
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_via_q8_1(
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
@ -728,7 +737,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ impl ArgSort {
|
|||||||
mod cuda {
|
mod cuda {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::cuda_backend::cudarc::driver::{
|
use crate::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
||||||
use crate::{CudaDevice, WithDType};
|
use crate::{CudaDevice, WithDType};
|
||||||
@ -69,6 +69,8 @@ mod cuda {
|
|||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
_wrap: W,
|
_wrap: W,
|
||||||
) -> Result<S> {
|
) -> Result<S> {
|
||||||
|
use cudarc::driver::PushKernelArg;
|
||||||
|
|
||||||
let slice = match layout.contiguous_offsets() {
|
let slice = match layout.contiguous_offsets() {
|
||||||
None => crate::bail!("input has to be contiguous"),
|
None => crate::bail!("input has to be contiguous"),
|
||||||
Some((o1, o2)) => src.slice(o1..o2),
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
@ -76,20 +78,24 @@ mod cuda {
|
|||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||||
let func = if self.asc {
|
let func = if self.asc {
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
||||||
} else {
|
} else {
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
|
||||||
};
|
};
|
||||||
let ncols = self.last_dim;
|
let ncols = self.last_dim;
|
||||||
let nrows = elem_count / ncols;
|
let nrows = elem_count / ncols;
|
||||||
let ncols_pad = next_power_of_2(ncols);
|
let ncols_pad = next_power_of_2(ncols);
|
||||||
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
|
||||||
let cfg = LaunchConfig {
|
let cfg = LaunchConfig {
|
||||||
grid_dim: (1, nrows as u32, 1),
|
grid_dim: (1, nrows as u32, 1),
|
||||||
block_dim: (ncols_pad as u32, 1, 1),
|
block_dim: (ncols_pad as u32, 1, 1),
|
||||||
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||||
};
|
};
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let stream = dev.cuda_stream();
|
||||||
|
let mut builder = stream.launch_builder(&func);
|
||||||
|
let ncols = ncols as i32;
|
||||||
|
let ncols_pad = ncols_pad as i32;
|
||||||
|
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(S::U32(dst))
|
Ok(S::U32(dst))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2580,6 +2580,28 @@ impl Tensor {
|
|||||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
|
||||||
|
/// This function makes a copy of the tensor’s data.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle_core::{Tensor, Device};
|
||||||
|
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
|
||||||
|
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
|
/// let t_flipped = t.flip(&[0])?;
|
||||||
|
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
|
||||||
|
let mut result = self.clone();
|
||||||
|
for &dim in dims.iter() {
|
||||||
|
let size = result.dim(dim)?;
|
||||||
|
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
|
||||||
|
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
|
||||||
|
result = result.index_select(&indices_tensor, dim)?;
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -24,6 +24,15 @@ macro_rules! test_device {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
|
||||||
|
assert_eq!(t1.shape(), t2.shape());
|
||||||
|
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
|
||||||
|
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
|
||||||
|
let all_equal = eq_tensor.sum_all()?;
|
||||||
|
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||||
let b = 10f32.powi(digits);
|
let b = 10f32.powi(digits);
|
||||||
let t = t.to_vec0::<f32>()?;
|
let t = t.to_vec0::<f32>()?;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#![allow(clippy::approx_constant)]
|
#![allow(clippy::approx_constant)]
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
|
||||||
|
|
||||||
fn simple_grad(device: &Device) -> Result<()> {
|
fn simple_grad(device: &Device) -> Result<()> {
|
||||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||||
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_backprop() -> Result<()> {
|
||||||
|
let device = &Device::Cpu;
|
||||||
|
|
||||||
|
// Create a tensor (leaf node) that requires gradients
|
||||||
|
let x = Var::ones((2, 2), DType::F64, device)?;
|
||||||
|
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
|
||||||
|
|
||||||
|
let y = x.matmul(&weights)?;
|
||||||
|
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
|
||||||
|
|
||||||
|
let z = y.flip(&[1])?;
|
||||||
|
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
|
||||||
|
|
||||||
|
let loss = z.sum_all()?;
|
||||||
|
|
||||||
|
let grad_store = loss.backward()?;
|
||||||
|
let grad_x = grad_store.get_id(x.id()).unwrap();
|
||||||
|
|
||||||
|
let flipped_weights = weights.flip(&[1])?;
|
||||||
|
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
|
||||||
|
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
|
||||||
|
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
test_device!(
|
test_device!(
|
||||||
simple_grad,
|
simple_grad,
|
||||||
simple_grad_cpu,
|
simple_grad_cpu,
|
||||||
|
@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_1d() -> Result<()> {
|
||||||
|
// 1D: [0, 1, 2, 3, 4]
|
||||||
|
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
|
||||||
|
let flipped = t.flip(&[0])?;
|
||||||
|
// Expected: [4, 3, 2, 1, 0]
|
||||||
|
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_2d() -> Result<()> {
|
||||||
|
// 2D:
|
||||||
|
// [[0, 1, 2],
|
||||||
|
// [3, 4, 5]]
|
||||||
|
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
|
||||||
|
let flipped = t.flip(&[0, 1])?;
|
||||||
|
// Expected:
|
||||||
|
// [[5, 4, 3],
|
||||||
|
// [2, 1, 0]]
|
||||||
|
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_3d_channels() -> Result<()> {
|
||||||
|
// 3D:
|
||||||
|
// [[[0,1,2],
|
||||||
|
// [3,4,5]],
|
||||||
|
//
|
||||||
|
// [[6,7,8],
|
||||||
|
// [9,10,11]]]
|
||||||
|
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
|
||||||
|
let flipped = t.flip(&[2])?;
|
||||||
|
// Expected:
|
||||||
|
// [[[2,1,0],
|
||||||
|
// [5,4,3]],
|
||||||
|
//
|
||||||
|
// [[8,7,6],
|
||||||
|
// [11,10,9]]]
|
||||||
|
let expected = Tensor::from_vec(
|
||||||
|
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
|
||||||
|
(2, 2, 3),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
if let parquet::record::Field::Group(subrow) = field {
|
if let parquet::record::Field::Group(subrow) = field {
|
||||||
for (_name, field) in subrow.get_column_iter() {
|
for (_name, field) in subrow.get_column_iter() {
|
||||||
if let parquet::record::Field::Bytes(value) = field {
|
if let parquet::record::Field::Bytes(value) = field {
|
||||||
|
// image-rs crate convention is to load in (width, height, channels) order
|
||||||
|
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
|
||||||
let image = image::load_from_memory(value.data()).unwrap();
|
let image = image::load_from_memory(value.data()).unwrap();
|
||||||
buffer_images.extend(image.to_rgb8().as_raw());
|
buffer_images.extend(image.to_rgb8().as_raw());
|
||||||
}
|
}
|
||||||
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
|
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
|
||||||
.to_dtype(DType::U8)?
|
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.permute((0, 3, 2, 1))?
|
||||||
/ 255.)?;
|
/ 255.)?;
|
||||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||||
Ok((images, labels))
|
Ok((images, labels))
|
||||||
|
13
candle-examples/examples/chatglm/README.md
Normal file
13
candle-examples/examples/chatglm/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-chatglm
|
||||||
|
|
||||||
|
Uses `THUDM/chatglm3-6b` to generate chinese text. Will not generate text for english (usually).
|
||||||
|
|
||||||
|
## Text Generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example chatglm --release -- --prompt "部署门槛较低等众多优秀特 "
|
||||||
|
|
||||||
|
> 部署门槛较低等众多优秀特 点,使得其成为了一款备受欢迎的AI助手。
|
||||||
|
>
|
||||||
|
> 作为一款人工智能助手,ChatGLM3-6B
|
||||||
|
```
|
42
candle-examples/examples/chinese_clip/README.md
Normal file
42
candle-examples/examples/chinese_clip/README.md
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# candle-chinese-clip
|
||||||
|
|
||||||
|
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
pairs of images with related texts. This one is trained using in chinese instead of english.
|
||||||
|
|
||||||
|
## Running on cpu
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
|
||||||
|
|
||||||
|
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
|
||||||
|
>
|
||||||
|
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running on metal
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --features metal --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
|
||||||
|
|
||||||
|
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
|
||||||
|
>
|
||||||
|
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
```
|
17
candle-examples/examples/convmixer/README.md
Normal file
17
candle-examples/examples/convmixer/README.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# candle-convmixer
|
||||||
|
|
||||||
|
A lightweight CNN architecture that processes image patches similar to a vision transformer, with separate spatial and channel convolutions.
|
||||||
|
|
||||||
|
ConvMixer from [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792) and [ConvMixer](https://github.com/locuslab/convmixer).
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
> mountain bike, all-terrain bike, off-roader: 61.75%
|
||||||
|
> unicycle, monocycle : 5.73%
|
||||||
|
> moped : 3.66%
|
||||||
|
> bicycle-built-for-two, tandem bicycle, tandem: 3.51%
|
||||||
|
> crash helmet : 0.85%
|
||||||
|
```
|
14
candle-examples/examples/csm/README.md
Normal file
14
candle-examples/examples/csm/README.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Conversational Speech Model (CSM)
|
||||||
|
|
||||||
|
CSM is a speech generation model from Sesame,
|
||||||
|
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
|
||||||
|
|
||||||
|
It can generate a conversational speech between two different speakers.
|
||||||
|
The speakers turn are delimited by the `|` character in the prompt.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example csm --features cuda -r -- \
|
||||||
|
--voices voices.safetensors \
|
||||||
|
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
||||||
|
```
|
||||||
|
|
243
candle-examples/examples/csm/main.rs
Normal file
243
candle-examples/examples/csm/main.rs
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::csm::{Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, IndexOp, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "1b")]
|
||||||
|
Csm1b,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
use_flash_attn: bool,
|
||||||
|
|
||||||
|
/// The prompt to be used for the generation, use a | to separate the speakers.
|
||||||
|
#[arg(long, default_value = "Hey how are you doing today?")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The voices to be used, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
voices: String,
|
||||||
|
|
||||||
|
/// The output file using the wav format.
|
||||||
|
#[arg(long, default_value = "out.wav")]
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long, default_value_t = 0.7)]
|
||||||
|
temperature: f64,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "1b")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weights: Option<String>,
|
||||||
|
|
||||||
|
/// The mimi model weight file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
mimi_weights: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id,
|
||||||
|
None => {
|
||||||
|
let name = match args.which {
|
||||||
|
Which::Csm1b => "sesame/csm-1b",
|
||||||
|
};
|
||||||
|
name.to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let filenames = match args.weights {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => vec![repo.get("model.safetensors")?],
|
||||||
|
};
|
||||||
|
let tokenizer_filename = match args.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => api
|
||||||
|
.model("meta-llama/Llama-3.2-1B".to_string())
|
||||||
|
.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let mimi_filename = match args.mimi_weights {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("kyutai/mimi".to_string())
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config: Config = match args.config {
|
||||||
|
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||||
|
None => {
|
||||||
|
let config_file = repo.get("config.json")?;
|
||||||
|
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let (mut model, device) = {
|
||||||
|
let dtype = device.bf16_default_to_f32();
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
(model, device)
|
||||||
|
};
|
||||||
|
let mut mimi_model = {
|
||||||
|
use candle_transformers::models::mimi;
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
|
||||||
|
let config = mimi::Config::v0_1(Some(32));
|
||||||
|
mimi::Model::new(config, vb)?
|
||||||
|
};
|
||||||
|
let cb = config.audio_num_codebooks;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let voices = candle::safetensors::load(args.voices, &device)?;
|
||||||
|
let mut lp = candle_transformers::generation::LogitsProcessor::new(
|
||||||
|
args.seed,
|
||||||
|
Some(args.temperature),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let tokens = voices
|
||||||
|
.get("tokens")
|
||||||
|
.expect("no tokens in prompt")
|
||||||
|
.to_dtype(DType::U32)?;
|
||||||
|
let mask = voices.get("mask").expect("no mask in prompt").clone();
|
||||||
|
|
||||||
|
let mut pos = 0;
|
||||||
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
|
|
||||||
|
let mut all_pcms = vec![];
|
||||||
|
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
|
||||||
|
println!("{prompt:?}");
|
||||||
|
let speaker_idx = turn_idx % 2;
|
||||||
|
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
|
||||||
|
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
|
||||||
|
|
||||||
|
let mut generated_tokens = vec![];
|
||||||
|
loop {
|
||||||
|
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
|
let is_done = frame.iter().all(|&x| x == 0);
|
||||||
|
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
|
||||||
|
print!("\rframe {pos}");
|
||||||
|
if is_done {
|
||||||
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
generated_tokens.push(tokens.clone());
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||||
|
let pcm = mimi_model.decode(&generated_tokens)?;
|
||||||
|
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||||
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||||
|
all_pcms.push(pcm);
|
||||||
|
}
|
||||||
|
let pcm = Tensor::cat(&all_pcms, 0)?;
|
||||||
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
println!("writing output file {}", args.out_file);
|
||||||
|
let mut output = std::fs::File::create(args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
17
candle-examples/examples/custom-ops/README.md
Normal file
17
candle-examples/examples/custom-ops/README.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# candle-custom-ops
|
||||||
|
|
||||||
|
This example illustrates how to implement forward and backward passes for custom operations on the CPU and GPU.
|
||||||
|
The custom op in this example implements RMS normalization for the CPU and CUDA.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example custom-ops
|
||||||
|
|
||||||
|
> [[ 0., 1., 2., 3., 4., 5., 6.],
|
||||||
|
> [ 7., 8., 9., 10., 11., 12., 13.]]
|
||||||
|
> Tensor[[2, 7], f32]
|
||||||
|
> [[0.0000, 0.2773, 0.5547, 0.8320, 1.1094, 1.3867, 1.6641],
|
||||||
|
> [0.6864, 0.7845, 0.8825, 0.9806, 1.0786, 1.1767, 1.2748]]
|
||||||
|
> Tensor[[2, 7], f32]
|
||||||
|
```
|
@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
|
|||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::backend::BackendStorage;
|
use candle::backend::BackendStorage;
|
||||||
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
|
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
|
||||||
use candle::cuda_backend::WrapErr;
|
use candle::cuda_backend::WrapErr;
|
||||||
let (d1, d2) = layout.shape().dims2()?;
|
let (d1, d2) = layout.shape().dims2()?;
|
||||||
let d1 = d1 as u32;
|
let d1 = d1 as u32;
|
||||||
@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
|
|||||||
};
|
};
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
|
let func =
|
||||||
let params = (&dst, &slice, self.eps, d1, d2);
|
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||||
let cfg = LaunchConfig {
|
let cfg = LaunchConfig {
|
||||||
grid_dim: (d1, 1, 1),
|
grid_dim: (d1, 1, 1),
|
||||||
block_dim: (d2, 1, 1),
|
block_dim: (d2, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&dst);
|
||||||
|
builder.arg(&slice);
|
||||||
|
candle::builder_arg!(builder, self.eps, d1, d2);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
|
|
||||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||||
Ok((dst, layout.shape().clone()))
|
Ok((dst, layout.shape().clone()))
|
||||||
|
15
candle-examples/examples/efficientnet/README.md
Normal file
15
candle-examples/examples/efficientnet/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# candle-efficientnet
|
||||||
|
|
||||||
|
Demonstrates a Candle implementation of EfficientNet for image classification based on ImageNet classes.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example efficientnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which b1
|
||||||
|
|
||||||
|
> bicycle-built-for-two, tandem bicycle, tandem: 45.85%
|
||||||
|
> mountain bike, all-terrain bike, off-roader: 30.45%
|
||||||
|
> crash helmet : 2.58%
|
||||||
|
> unicycle, monocycle : 2.21%
|
||||||
|
> tricycle, trike, velocipede: 1.53%
|
||||||
|
```
|
@ -1,3 +1,10 @@
|
|||||||
# candle-falcon
|
# candle-falcon
|
||||||
|
|
||||||
Falcon is a general large language model.
|
Falcon is a general large language model.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
Make sure to include the `--use-f32` flag if using CPU, because there isn't a BFloat16 implementation yet.
|
||||||
|
```
|
||||||
|
cargo run --example falcon --release -- --prompt "Flying monkeys are" --use-f32
|
||||||
|
```
|
@ -12,7 +12,7 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
|
|||||||
|
|
||||||
** Running with ~cpu~
|
** Running with ~cpu~
|
||||||
#+begin_src shell
|
#+begin_src shell
|
||||||
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
|
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
** Output Example
|
** Output Example
|
||||||
|
11
candle-examples/examples/llama/README.md
Normal file
11
candle-examples/examples/llama/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# candle-llama
|
||||||
|
|
||||||
|
Candle implementations of various Llama based architectures.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example llama -- --prompt "Machine learning is " --which v32-3b-instruct
|
||||||
|
|
||||||
|
> Machine learning is the part of computer science which deals with the development of algorithms and
|
||||||
|
```
|
@ -21,7 +21,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_conv(&self) -> usize {
|
fn d_conv(&self) -> usize {
|
||||||
|
@ -12,6 +12,6 @@ would only work for inference.
|
|||||||
## Running the example
|
## Running the example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
|
$ cargo run --example mamba --release -- --prompt "Mamba is the"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
|
|||||||
mountain. I cannot stay far from you any longer.</s>
|
mountain. I cannot stay far from you any longer.</s>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Changing model and language pairs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
|
||||||
|
|
||||||
|
你好,你好吗?
|
||||||
|
```
|
||||||
|
|
||||||
## Generating the tokenizer.json files
|
## Generating the tokenizer.json files
|
||||||
|
|
||||||
You can use the following script to generate the `tokenizer.json` config files
|
The tokenizer for each `marian-mt` model was trained independently,
|
||||||
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
|
meaning each new model needs unique tokenizer encoders and decoders.
|
||||||
packages to be install and use the `convert_slow_tokenizer.py` script from this
|
You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
|
||||||
directory.
|
the `tokenizer.json` config files from the hf-hub repos.
|
||||||
|
The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
|
||||||
```python
|
to be installed, and has only been tested for `python 3.12.7`.
|
||||||
from convert_slow_tokenizer import MarianConverter
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
|
|
||||||
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
|
|
||||||
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
|
|
||||||
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
|
|
||||||
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
|
|
||||||
```
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -20,6 +20,22 @@ enum Which {
|
|||||||
Big,
|
Big,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum LanguagePair {
|
||||||
|
#[value(name = "fr-en")]
|
||||||
|
FrEn,
|
||||||
|
#[value(name = "en-zh")]
|
||||||
|
EnZh,
|
||||||
|
#[value(name = "en-hi")]
|
||||||
|
EnHi,
|
||||||
|
#[value(name = "en-es")]
|
||||||
|
EnEs,
|
||||||
|
#[value(name = "en-fr")]
|
||||||
|
EnFr,
|
||||||
|
#[value(name = "en-ru")]
|
||||||
|
EnRu,
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Maybe add support for the conditional prompt.
|
// TODO: Maybe add support for the conditional prompt.
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -36,6 +52,10 @@ struct Args {
|
|||||||
#[arg(long, default_value = "big")]
|
#[arg(long, default_value = "big")]
|
||||||
which: Which,
|
which: Which,
|
||||||
|
|
||||||
|
// Choose which language pair to use
|
||||||
|
#[arg(long, default_value = "fr-en")]
|
||||||
|
language_pair: LanguagePair,
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
/// Run on CPU rather than on GPU.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let config = match args.which {
|
let config = match (args.which, args.language_pair) {
|
||||||
Which::Base => marian::Config::opus_mt_fr_en(),
|
(Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
|
||||||
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
|
(Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
|
||||||
|
(Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
|
||||||
|
(Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
|
||||||
|
(Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
|
||||||
|
(Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
|
||||||
|
(Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
|
||||||
|
(Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
|
||||||
|
};
|
||||||
|
let tokenizer_default_repo = match args.language_pair {
|
||||||
|
LanguagePair::FrEn => "lmz/candle-marian",
|
||||||
|
LanguagePair::EnZh
|
||||||
|
| LanguagePair::EnHi
|
||||||
|
| LanguagePair::EnEs
|
||||||
|
| LanguagePair::EnFr
|
||||||
|
| LanguagePair::EnRu => "KeighBee/candle-marian",
|
||||||
};
|
};
|
||||||
let tokenizer = {
|
let tokenizer = {
|
||||||
let tokenizer = match args.tokenizer {
|
let tokenizer = match args.tokenizer {
|
||||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
None => {
|
None => {
|
||||||
let name = match args.which {
|
let filename = match (args.which, args.language_pair) {
|
||||||
Which::Base => "tokenizer-marian-base-fr.json",
|
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
|
||||||
Which::Big => "tokenizer-marian-fr.json",
|
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
|
||||||
|
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
|
||||||
|
(Which::Big, lp) => {
|
||||||
|
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Api::new()?
|
Api::new()?
|
||||||
.model("lmz/candle-marian".to_string())
|
.model(tokenizer_default_repo.to_string())
|
||||||
.get(name)?
|
.get(filename)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let tokenizer = match args.tokenizer_dec {
|
let tokenizer = match args.tokenizer_dec {
|
||||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
None => {
|
None => {
|
||||||
let name = match args.which {
|
let filename = match (args.which, args.language_pair) {
|
||||||
Which::Base => "tokenizer-marian-base-en.json",
|
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
|
||||||
Which::Big => "tokenizer-marian-en.json",
|
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
|
||||||
|
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
|
||||||
|
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
|
||||||
|
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
|
||||||
|
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
|
||||||
|
(Which::Big, lp) => {
|
||||||
|
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Api::new()?
|
Api::new()?
|
||||||
.model("lmz/candle-marian".to_string())
|
.model(tokenizer_default_repo.to_string())
|
||||||
.get(name)?
|
.get(filename)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let vb = {
|
let vb = {
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
Some(model) => std::path::PathBuf::from(model),
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
None => match args.which {
|
None => {
|
||||||
Which::Base => Api::new()?
|
let api = Api::new()?;
|
||||||
.repo(hf_hub::Repo::with_revision(
|
let api = match (args.which, args.language_pair) {
|
||||||
|
(Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
|
||||||
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
||||||
hf_hub::RepoType::Model,
|
hf_hub::RepoType::Model,
|
||||||
"refs/pr/4".to_string(),
|
"refs/pr/4".to_string(),
|
||||||
))
|
)),
|
||||||
.get("model.safetensors")?,
|
(Which::Big, LanguagePair::FrEn) => {
|
||||||
Which::Big => Api::new()?
|
api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||||
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
}
|
||||||
.get("model.safetensors")?,
|
(Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
|
||||||
},
|
"Helsinki-NLP/opus-mt-en-zh".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/13".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-hi".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/3".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-es".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/4".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-fr".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/9".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-ru".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/7".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Big, lp) => {
|
||||||
|
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf
|
||||||
|
|
||||||
|
class MarianConverter(SpmConverter):
|
||||||
|
def __init__(self, *args, index: int = 0):
|
||||||
|
requires_backends(self, "protobuf")
|
||||||
|
|
||||||
|
super(SpmConverter, self).__init__(*args)
|
||||||
|
|
||||||
|
# from .utils import sentencepiece_model_pb2 as model_pb2
|
||||||
|
model_pb2 = import_protobuf()
|
||||||
|
|
||||||
|
m = model_pb2.ModelProto()
|
||||||
|
print(self.original_tokenizer.spm_files)
|
||||||
|
with open(self.original_tokenizer.spm_files[index], "rb") as f:
|
||||||
|
m.ParseFromString(f.read())
|
||||||
|
self.proto = m
|
||||||
|
print(self.original_tokenizer)
|
||||||
|
#with open(self.original_tokenizer.vocab_path, "r") as f:
|
||||||
|
dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
|
||||||
|
with open(dir_path / "vocab.json", "r") as f:
|
||||||
|
import json
|
||||||
|
self._vocab = json.load(f)
|
||||||
|
|
||||||
|
if self.proto.trainer_spec.byte_fallback:
|
||||||
|
if not getattr(self, "handle_byte_fallback", None):
|
||||||
|
warnings.warn(
|
||||||
|
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
|
||||||
|
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
|
||||||
|
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
|
||||||
|
"unknown tokens into a sequence of byte tokens matching the original piece of text."
|
||||||
|
)
|
||||||
|
|
||||||
|
def vocab(self, proto):
|
||||||
|
vocab_size = max(self._vocab.values()) + 1
|
||||||
|
vocab = [("<NIL>", -100) for _ in range(vocab_size)]
|
||||||
|
for piece in proto.pieces:
|
||||||
|
try:
|
||||||
|
index = self._vocab[piece.piece]
|
||||||
|
except Exception:
|
||||||
|
print(f"Ignored missing piece {piece.piece}")
|
||||||
|
vocab[index] = (piece.piece, piece.score)
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
|
||||||
|
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
|
||||||
|
fast_tokenizer.save("tokenizer-marian-base-fr.json")
|
||||||
|
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
|
||||||
|
fast_tokenizer.save("tokenizer-marian-base-en.json")
|
22
candle-examples/examples/marian-mt/python/requirements.txt
Normal file
22
candle-examples/examples/marian-mt/python/requirements.txt
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
certifi==2025.1.31
|
||||||
|
charset-normalizer==3.4.1
|
||||||
|
click==8.1.8
|
||||||
|
filelock==3.18.0
|
||||||
|
fsspec==2025.3.2
|
||||||
|
huggingface-hub==0.30.1
|
||||||
|
idna==3.10
|
||||||
|
joblib==1.4.2
|
||||||
|
numpy==2.2.4
|
||||||
|
packaging==24.2
|
||||||
|
protobuf==6.30.2
|
||||||
|
pyyaml==6.0.2
|
||||||
|
regex==2024.11.6
|
||||||
|
requests==2.32.3
|
||||||
|
sacremoses==0.1.1
|
||||||
|
safetensors==0.5.3
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
tokenizers==0.21.1
|
||||||
|
tqdm==4.67.1
|
||||||
|
transformers==4.50.3
|
||||||
|
typing-extensions==4.13.0
|
||||||
|
urllib3==2.3.0
|
@ -13,6 +13,6 @@ Note that the current candle implementation suffers from some limitations as of
|
|||||||
## Run an example
|
## Run an example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example metavoice --release -- \\
|
cargo run --example metavoice --release -- \
|
||||||
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
|
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
|
||||||
```
|
```
|
||||||
|
16
candle-examples/examples/mnist-training/README.md
Normal file
16
candle-examples/examples/mnist-training/README.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# candle-mnist-training
|
||||||
|
|
||||||
|
Training a 2 layer MLP on mnist in Candle.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example mnist-training --features candle-datasets
|
||||||
|
|
||||||
|
> train-images: [60000, 784]
|
||||||
|
> train-labels: [60000]
|
||||||
|
> test-images: [10000, 784]
|
||||||
|
> test-labels: [10000]
|
||||||
|
> 1 train loss: 2.30265 test acc: 68.08%
|
||||||
|
> 2 train loss: 1.50815 test acc: 60.77%
|
||||||
|
```
|
@ -12,7 +12,7 @@ $ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jp
|
|||||||
|
|
||||||
Now you can run Moondream from the `candle-examples` crate:
|
Now you can run Moondream from the `candle-examples` crate:
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
|
$ cargo run --example moondream --release -- --prompt "Describe the people behind the bikers?" --image "candle-examples/examples/yolo-v8/assets/bike.jpg"
|
||||||
|
|
||||||
avavx: false, neon: true, simd128: false, f16c: false
|
avavx: false, neon: true, simd128: false, f16c: false
|
||||||
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
|
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
|
||||||
|
20
candle-examples/examples/musicgen/README.md
Normal file
20
candle-examples/examples/musicgen/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-musicgen
|
||||||
|
|
||||||
|
Candle implementation of musicgen from [Simple and Controllable Music Generation](https://arxiv.org/pdf/2306.05284).
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums"
|
||||||
|
|
||||||
|
> tokens: [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7, 1]
|
||||||
|
> Tensor[dims 1, 13; u32]
|
||||||
|
> [[[ 0.0902, 0.1256, -0.0585, ..., 0.1057, -0.5141, -0.4675],
|
||||||
|
> [ 0.1972, -0.0268, -0.3368, ..., -0.0495, -0.3597, -0.3940],
|
||||||
|
> [-0.0855, -0.0007, 0.2225, ..., -0.2804, -0.5360, -0.2436],
|
||||||
|
> ...
|
||||||
|
> [ 0.0515, 0.0235, -0.3855, ..., -0.4728, -0.6858, -0.2923],
|
||||||
|
> [-0.3728, -0.1442, -0.1179, ..., -0.4388, -0.0287, -0.3242],
|
||||||
|
> [ 0.0163, 0.0012, -0.0020, ..., 0.0142, 0.0173, -0.0103]]]
|
||||||
|
> Tensor[[1, 13, 768], f32]
|
||||||
|
```
|
20
candle-examples/examples/quantized-phi/README.md
Normal file
20
candle-examples/examples/quantized-phi/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-quantized-phi
|
||||||
|
|
||||||
|
Candle implementation of various quantized Phi models.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example quantized-phi --release -- --prompt "The best thing about coding in rust is "
|
||||||
|
|
||||||
|
> - it's memory safe (without you having to worry too much)
|
||||||
|
> - the borrow checker is really smart and will catch your mistakes for free, making them show up as compile errors instead of segfaulting in runtime.
|
||||||
|
>
|
||||||
|
> This alone make me prefer using rust over c++ or go, python/Cython etc.
|
||||||
|
>
|
||||||
|
> The major downside I can see now:
|
||||||
|
> - it's slower than other languages (viz: C++) and most importantly lack of libraries to leverage existing work done by community in that language. There are so many useful machine learning libraries available for c++, go, python etc but none for Rust as far as I am aware of on the first glance.
|
||||||
|
> - there aren't a lot of production ready projects which also makes it very hard to start new one (given my background)
|
||||||
|
>
|
||||||
|
> Another downside:
|
||||||
|
```
|
@ -27,6 +27,8 @@ enum Which {
|
|||||||
W2_7b,
|
W2_7b,
|
||||||
#[value(name = "72b")]
|
#[value(name = "72b")]
|
||||||
W2_72b,
|
W2_72b,
|
||||||
|
#[value(name = "deepseekr1-qwen7b")]
|
||||||
|
DeepseekR1Qwen7B,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -102,6 +104,7 @@ impl Args {
|
|||||||
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
|
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
|
||||||
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
|
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
|
||||||
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
|
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
|
||||||
|
Which::DeepseekR1Qwen7B => "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
||||||
};
|
};
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
api.get("tokenizer.json")?
|
api.get("tokenizer.json")?
|
||||||
@ -135,6 +138,11 @@ impl Args {
|
|||||||
"qwen2-72b-instruct-q4_0.gguf",
|
"qwen2-72b-instruct-q4_0.gguf",
|
||||||
"main",
|
"main",
|
||||||
),
|
),
|
||||||
|
Which::DeepseekR1Qwen7B => (
|
||||||
|
"unsloth/DeepSeek-R1-Distill-Qwen-7B-GGUF",
|
||||||
|
"DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf",
|
||||||
|
"main",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
api.repo(hf_hub::Repo::with_revision(
|
api.repo(hf_hub::Repo::with_revision(
|
||||||
@ -211,11 +219,15 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let tokenizer = args.tokenizer()?;
|
let tokenizer = args.tokenizer()?;
|
||||||
let mut tos = TokenOutputStream::new(tokenizer);
|
let mut tos = TokenOutputStream::new(tokenizer);
|
||||||
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
let prompt_str = args
|
||||||
let prompt_str = format!(
|
.prompt
|
||||||
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
.clone()
|
||||||
prompt_str
|
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||||
);
|
|
||||||
|
let prompt_str = match args.which {
|
||||||
|
Which::DeepseekR1Qwen7B => format!("<|User|>{prompt_str}<|Assistant|>"),
|
||||||
|
_ => format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"),
|
||||||
|
};
|
||||||
print!("formatted instruct prompt: {}", &prompt_str);
|
print!("formatted instruct prompt: {}", &prompt_str);
|
||||||
let tokens = tos
|
let tokens = tos
|
||||||
.tokenizer()
|
.tokenizer()
|
||||||
@ -260,7 +272,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
print!("{t}");
|
print!("{t}");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
|
|
||||||
|
let eos_token = match args.which {
|
||||||
|
Which::DeepseekR1Qwen7B => "<|end▁of▁sentence|>",
|
||||||
|
_ => "<|im_end|>",
|
||||||
|
};
|
||||||
|
|
||||||
|
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
let mut sampled = 0;
|
||||||
for index in 0..to_sample {
|
for index in 0..to_sample {
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# candle-quantized-t5
|
# candle-quantized-t5
|
||||||
|
|
||||||
|
Candle implementation for quantizing and running T5 translation models.
|
||||||
|
|
||||||
## Seq2Seq example
|
## Seq2Seq example
|
||||||
|
|
||||||
This example uses a quantized version of the t5 model.
|
This example uses a quantized version of the t5 model.
|
||||||
|
@ -75,6 +75,8 @@ enum Which {
|
|||||||
SmolLM2_360MInstruct,
|
SmolLM2_360MInstruct,
|
||||||
#[value(name = "SmoLM2-1.7B-Instruct")]
|
#[value(name = "SmoLM2-1.7B-Instruct")]
|
||||||
SmolLM2_1BInstruct,
|
SmolLM2_1BInstruct,
|
||||||
|
#[value(name = "deepseekr1-llama8b")]
|
||||||
|
DeepseekR1Llama8b,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -94,7 +96,8 @@ impl Which {
|
|||||||
| Self::L8b
|
| Self::L8b
|
||||||
| Self::Phi3
|
| Self::Phi3
|
||||||
| Self::SmolLM2_1BInstruct
|
| Self::SmolLM2_1BInstruct
|
||||||
| Self::SmolLM2_360MInstruct => false,
|
| Self::SmolLM2_360MInstruct
|
||||||
|
| Self::DeepseekR1Llama8b => false,
|
||||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||||
// same way. Starling is a fine tuned version of OpenChat.
|
// same way. Starling is a fine tuned version of OpenChat.
|
||||||
Self::OpenChat35
|
Self::OpenChat35
|
||||||
@ -132,7 +135,8 @@ impl Which {
|
|||||||
| Self::L8b
|
| Self::L8b
|
||||||
| Self::SmolLM2_1BInstruct
|
| Self::SmolLM2_1BInstruct
|
||||||
| Self::SmolLM2_360MInstruct
|
| Self::SmolLM2_360MInstruct
|
||||||
| Self::Phi3 => false,
|
| Self::Phi3
|
||||||
|
| Self::DeepseekR1Llama8b => false,
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -160,11 +164,41 @@ impl Which {
|
|||||||
| Self::L8b
|
| Self::L8b
|
||||||
| Self::SmolLM2_1BInstruct
|
| Self::SmolLM2_1BInstruct
|
||||||
| Self::SmolLM2_360MInstruct
|
| Self::SmolLM2_360MInstruct
|
||||||
| Self::Phi3 => false,
|
| Self::Phi3
|
||||||
|
| Self::DeepseekR1Llama8b => false,
|
||||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_deepseek(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::L7b
|
||||||
|
| Self::L13b
|
||||||
|
| Self::L70b
|
||||||
|
| Self::L7bChat
|
||||||
|
| Self::L13bChat
|
||||||
|
| Self::L70bChat
|
||||||
|
| Self::L7bCode
|
||||||
|
| Self::L13bCode
|
||||||
|
| Self::L34bCode
|
||||||
|
| Self::Leo7b
|
||||||
|
| Self::Leo13b
|
||||||
|
| Self::Mixtral
|
||||||
|
| Self::MixtralInstruct
|
||||||
|
| Self::Mistral7b
|
||||||
|
| Self::Mistral7bInstruct
|
||||||
|
| Self::Mistral7bInstructV02
|
||||||
|
| Self::Zephyr7bAlpha
|
||||||
|
| Self::Zephyr7bBeta
|
||||||
|
| Self::L8b
|
||||||
|
| Self::SmolLM2_1BInstruct
|
||||||
|
| Self::SmolLM2_360MInstruct
|
||||||
|
| Self::Phi3
|
||||||
|
| Self::OpenChat35
|
||||||
|
| Self::Starling7bAlpha => false,
|
||||||
|
Self::DeepseekR1Llama8b => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
fn tokenizer_repo(&self) -> &'static str {
|
fn tokenizer_repo(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::L7b
|
Self::L7b
|
||||||
@ -191,6 +225,7 @@ impl Which {
|
|||||||
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
||||||
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||||
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||||
|
Self::DeepseekR1Llama8b => "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -363,6 +398,10 @@ impl Args {
|
|||||||
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
|
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
|
||||||
"smollm2-1.7b-instruct-q4_k_m.gguf",
|
"smollm2-1.7b-instruct-q4_k_m.gguf",
|
||||||
),
|
),
|
||||||
|
Which::DeepseekR1Llama8b => (
|
||||||
|
"unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
|
||||||
|
"DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let revision = if self.which == Which::Phi3 {
|
let revision = if self.which == Which::Phi3 {
|
||||||
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
|
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
|
||||||
@ -477,6 +516,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L8b
|
| Which::L8b
|
||||||
| Which::SmolLM2_1BInstruct
|
| Which::SmolLM2_1BInstruct
|
||||||
| Which::SmolLM2_360MInstruct
|
| Which::SmolLM2_360MInstruct
|
||||||
|
| Which::DeepseekR1Llama8b
|
||||||
| Which::Phi3 => 1,
|
| Which::Phi3 => 1,
|
||||||
Which::Mixtral
|
Which::Mixtral
|
||||||
| Which::MixtralInstruct
|
| Which::MixtralInstruct
|
||||||
@ -530,6 +570,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
} else if args.which.is_mistral() {
|
} else if args.which.is_mistral() {
|
||||||
format!("[INST] {prompt} [/INST]")
|
format!("[INST] {prompt} [/INST]")
|
||||||
|
} else if args.which.is_deepseek() {
|
||||||
|
format!("<|User|>{prompt}<|Assistant|>")
|
||||||
} else {
|
} else {
|
||||||
prompt
|
prompt
|
||||||
}
|
}
|
||||||
@ -597,6 +639,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let eos_token = match args.which {
|
let eos_token = match args.which {
|
||||||
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
|
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
|
||||||
Which::L8b => "<|end_of_text|>",
|
Which::L8b => "<|end_of_text|>",
|
||||||
|
Which::DeepseekR1Llama8b => "<|end▁of▁sentence|>",
|
||||||
_ => match args.which.is_open_chat() {
|
_ => match args.which.is_open_chat() {
|
||||||
true => "<|end_of_turn|>",
|
true => "<|end_of_turn|>",
|
||||||
false => "</s>",
|
false => "</s>",
|
||||||
|
@ -2,6 +2,11 @@
|
|||||||
|
|
||||||
Reinforcement Learning examples for candle.
|
Reinforcement Learning examples for candle.
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
|
> uv is not currently compatible with pyo3 as of 2025/3/28.
|
||||||
|
|
||||||
|
## System wide python
|
||||||
|
|
||||||
This has been tested with `gymnasium` version `0.29.1`. You can install the
|
This has been tested with `gymnasium` version `0.29.1`. You can install the
|
||||||
Python package with:
|
Python package with:
|
||||||
```bash
|
```bash
|
||||||
|
@ -7,7 +7,7 @@ probabilities for the top-5 classes.
|
|||||||
## Running an example
|
## Running an example
|
||||||
|
|
||||||
```
|
```
|
||||||
$ cargo run --example resnet --release -- --image tiger.jpg
|
$ cargo run --example resnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
loaded image Tensor[dims 3, 224, 224; f32]
|
||||||
model built
|
model built
|
||||||
|
@ -10,9 +10,11 @@ If you want you can use the example images from this [pull request][pr], downloa
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# run the image classification task
|
# run the image classification task
|
||||||
cargo run --example segformer classify <path-to-image>
|
cargo run --example segformer classify candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
# run the segmentation task
|
# run the segmentation task
|
||||||
cargo run --example segformer segment <path-to-image>
|
cargo run --example segformer segment candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Example output for classification:
|
Example output for classification:
|
||||||
|
@ -14,8 +14,8 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example segment-anything --release -- \
|
cargo run --example segment-anything --release -- \
|
||||||
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
--image candle-examples/examples/yolo-v8/assets/bike.jpg \
|
||||||
--use-tiny
|
--use-tiny \
|
||||||
--point 0.6,0.6 --point 0.6,0.55
|
--point 0.6,0.6 --point 0.6,0.55
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmo
|
|||||||
|
|
||||||
### Running an example
|
### Running an example
|
||||||
```
|
```
|
||||||
$ cargo run --features cuda -r --example siglip -
|
$ cargo run --features cuda -r --example siglip
|
||||||
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
|
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,14 @@ This example uses the models available in the hugging face [onnx-community/siler
|
|||||||
|
|
||||||
## Running the example
|
## Running the example
|
||||||
|
|
||||||
|
### using arecord
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### using SoX
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ rec -t raw -r 48000 -b 16 -c 1 -e signed-integer - trim 0 5 | sox -t raw -r 48000 -b 16 -c 1 -e signed-integer - -t raw -r 16000 -b 16 -c 1 -e signed-integer - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
||||||
|
```
|
||||||
|
15
candle-examples/examples/starcoder2/README.md
Normal file
15
candle-examples/examples/starcoder2/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# candle-starcoder2
|
||||||
|
|
||||||
|
Candle implementation of Star Coder 2 family of code generation model from [StarCoder 2 and The Stack v2: The Next Generation](https://arxiv.org/pdf/2402.19173).
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example starcoder2 -- --prompt "write a recursive fibonacci function in python "
|
||||||
|
|
||||||
|
> # that returns the nth number in the sequence.
|
||||||
|
>
|
||||||
|
> def fib(n):
|
||||||
|
> if n
|
||||||
|
|
||||||
|
```
|
@ -10,7 +10,7 @@ Stella_en_1.5B_v5 is used to generate text embeddings embeddings for a prompt. T
|
|||||||
are downloaded from the hub on the first run.
|
are downloaded from the hub on the first run.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
|
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?" --which 1.5b
|
||||||
|
|
||||||
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
|
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
|
||||||
> Tensor[[1, 1024], f32]
|
> Tensor[[1, 1024], f32]
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# candle-t5
|
# candle-t5
|
||||||
|
|
||||||
|
Candle implementations of the T5 family of translation models.
|
||||||
|
|
||||||
## Encoder-decoder example:
|
## Encoder-decoder example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -7,7 +7,7 @@ The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main
|
|||||||
You can run the example with the following command:
|
You can run the example with the following command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
|
cargo run --example vgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which vgg13
|
||||||
```
|
```
|
||||||
|
|
||||||
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
|
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
|
||||||
|
@ -7,8 +7,8 @@ probabilities for the top-5 classes.
|
|||||||
|
|
||||||
## Running an example
|
## Running an example
|
||||||
|
|
||||||
```
|
```bash
|
||||||
$ cargo run --example vit --release -- --image tiger.jpg
|
$ cargo run --example vit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
loaded image Tensor[dims 3, 224, 224; f32]
|
||||||
model built
|
model built
|
||||||
|
15
candle-examples/examples/whisper-microphone/README.md
Normal file
15
candle-examples/examples/whisper-microphone/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# candle-whisper-microphone
|
||||||
|
|
||||||
|
Whisper implementation using microphone as input.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example whisper-microphone --features microphone
|
||||||
|
|
||||||
|
> transcribing audio...
|
||||||
|
> 480256 160083
|
||||||
|
> language_token: None
|
||||||
|
> 0.0s -- 30.0s: Hello, hello, I don't know if this is working, but You know, how long did I make this?
|
||||||
|
> 480256 160085
|
||||||
|
```
|
13
candle-examples/examples/yi/README.md
Normal file
13
candle-examples/examples/yi/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-yi
|
||||||
|
|
||||||
|
Candle implentations of the Yi family of bilingual (English, Chinese) LLMs.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example yi -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
|
> python
|
||||||
|
> print("Hello World")
|
||||||
|
>
|
||||||
|
```
|
32
candle-examples/examples/yolo-v3/README.md
Normal file
32
candle-examples/examples/yolo-v3/README.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# candle-yolo-v3:
|
||||||
|
|
||||||
|
Candle implementation of Yolo-V3 for object detection.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example yolo-v3 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
> generated predictions Tensor[dims 10647, 85; f32]
|
||||||
|
> person: Bbox { xmin: 46.362198, ymin: 72.177, xmax: 135.92522, ymax: 339.8356, confidence: 0.99705493, data: () }
|
||||||
|
> person: Bbox { xmin: 137.25645, ymin: 67.58148, xmax: 216.90437, ymax: 333.80756, confidence: 0.9898516, data: () }
|
||||||
|
> person: Bbox { xmin: 245.7842, ymin: 82.76726, xmax: 316.79053, ymax: 337.21613, confidence: 0.9884322, data: () }
|
||||||
|
> person: Bbox { xmin: 207.52783, ymin: 61.815224, xmax: 266.77884, ymax: 307.92606, confidence: 0.9860648, data: () }
|
||||||
|
> person: Bbox { xmin: 11.457404, ymin: 60.335564, xmax: 34.39357, ymax: 187.7714, confidence: 0.9545012, data: () }
|
||||||
|
> person: Bbox { xmin: 251.88353, ymin: 11.235481, xmax: 286.56607, ymax: 92.54697, confidence: 0.8439807, data: () }
|
||||||
|
> person: Bbox { xmin: -0.44309902, ymin: 55.486923, xmax: 13.160354, ymax: 184.09705, confidence: 0.8266243, data: () }
|
||||||
|
> person: Bbox { xmin: 317.40826, ymin: 55.39501, xmax: 370.6704, ymax: 153.74887, confidence: 0.7327442, data: () }
|
||||||
|
> person: Bbox { xmin: 370.02835, ymin: 66.120224, xmax: 404.22824, ymax: 142.09691, confidence: 0.7265741, data: () }
|
||||||
|
> person: Bbox { xmin: 250.36511, ymin: 57.349842, xmax: 280.06335, ymax: 116.29384, confidence: 0.709422, data: () }
|
||||||
|
> person: Bbox { xmin: 32.573215, ymin: 66.66239, xmax: 50.49056, ymax: 173.42068, confidence: 0.6998766, data: () }
|
||||||
|
> person: Bbox { xmin: 131.72215, ymin: 63.946213, xmax: 166.66151, ymax: 241.52773, confidence: 0.64457536, data: () }
|
||||||
|
> person: Bbox { xmin: 407.42416, ymin: 49.106407, xmax: 415.24307, ymax: 84.7134, confidence: 0.5955802, data: () }
|
||||||
|
> person: Bbox { xmin: 51.650482, ymin: 64.4985, xmax: 67.40904, ymax: 106.952385, confidence: 0.5196007, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 160.10031, ymin: 183.90837, xmax: 200.86832, ymax: 398.609, confidence: 0.9623588, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 66.570915, ymin: 192.56966, xmax: 112.06765, ymax: 369.28497, confidence: 0.9174347, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 258.2856, ymin: 197.04532, xmax: 298.43106, ymax: 364.8627, confidence: 0.6851388, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 214.0034, ymin: 175.76498, xmax: 252.45158, ymax: 356.53818, confidence: 0.67071193, data: () }
|
||||||
|
> motorbike: Bbox { xmin: 318.23938, ymin: 95.22487, xmax: 369.9743, ymax: 213.46263, confidence: 0.96691036, data: () }
|
||||||
|
> motorbike: Bbox { xmin: 367.46417, ymin: 100.07982, xmax: 394.9981, ymax: 174.6545, confidence: 0.9185384, data: () }
|
||||||
|
> writing "candle-examples/examples/yolo-v8/assets/bike.pp.jpg"
|
||||||
|
```
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -88,6 +88,7 @@ impl FlashAttn {
|
|||||||
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
|
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream();
|
||||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||||
if alibi_slopes.dtype() != DType::F32 {
|
if alibi_slopes.dtype() != DType::F32 {
|
||||||
candle::bail!(
|
candle::bail!(
|
||||||
@ -114,7 +115,9 @@ impl FlashAttn {
|
|||||||
|
|
||||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||||
|
|
||||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
// Dropping the guard here doesn't seem very safe.
|
||||||
|
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
|
||||||
|
ptr as *const core::ffi::c_void
|
||||||
} else {
|
} else {
|
||||||
std::ptr::null()
|
std::ptr::null()
|
||||||
};
|
};
|
||||||
@ -161,17 +164,17 @@ impl FlashAttn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
let (q_ptr, _guard) = q.device_ptr(&stream);
|
||||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
let (k_ptr, _guard) = k.device_ptr(&stream);
|
||||||
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
|
let (v_ptr, _guard) = v.device_ptr(&stream);
|
||||||
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
|
let (dst_ptr, _guard) = dst.device_ptr(&stream);
|
||||||
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
|
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
|
||||||
ffi::run_mha(
|
ffi::run_mha(
|
||||||
q_ptr,
|
q_ptr as *const core::ffi::c_void,
|
||||||
k_ptr,
|
k_ptr as *const core::ffi::c_void,
|
||||||
v_ptr,
|
v_ptr as *const core::ffi::c_void,
|
||||||
dst_ptr,
|
dst_ptr as *const core::ffi::c_void,
|
||||||
softmax_lse_ptr,
|
softmax_lse_ptr as *const core::ffi::c_void,
|
||||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
||||||
/* cu_seqlens_q_ptr */ std::ptr::null(),
|
/* cu_seqlens_q_ptr */ std::ptr::null(),
|
||||||
/* cu_seqlens_k_ptr */ std::ptr::null(),
|
/* cu_seqlens_k_ptr */ std::ptr::null(),
|
||||||
@ -550,6 +553,7 @@ impl FlashAttnVarLen {
|
|||||||
|
|
||||||
let batch_size = nseqlens_q - 1;
|
let batch_size = nseqlens_q - 1;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream();
|
||||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||||
if alibi_slopes.dtype() != DType::F32 {
|
if alibi_slopes.dtype() != DType::F32 {
|
||||||
candle::bail!(
|
candle::bail!(
|
||||||
@ -576,7 +580,9 @@ impl FlashAttnVarLen {
|
|||||||
|
|
||||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||||
|
|
||||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
// Dropping the guard here doesn't seem very safe.
|
||||||
|
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
|
||||||
|
ptr as *const core::ffi::c_void
|
||||||
} else {
|
} else {
|
||||||
std::ptr::null()
|
std::ptr::null()
|
||||||
};
|
};
|
||||||
@ -621,22 +627,22 @@ impl FlashAttnVarLen {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
let (q_ptr, _guard) = q.device_ptr(&stream);
|
||||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
let (k_ptr, _guard) = k.device_ptr(&stream);
|
||||||
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
|
let (v_ptr, _guard) = v.device_ptr(&stream);
|
||||||
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
|
let (dst_ptr, _guard) = dst.device_ptr(&stream);
|
||||||
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
|
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
|
||||||
let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
|
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
|
||||||
let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
|
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
|
||||||
ffi::run_mha(
|
ffi::run_mha(
|
||||||
q_ptr,
|
q_ptr as *const core::ffi::c_void,
|
||||||
k_ptr,
|
k_ptr as *const core::ffi::c_void,
|
||||||
v_ptr,
|
v_ptr as *const core::ffi::c_void,
|
||||||
dst_ptr,
|
dst_ptr as *const core::ffi::c_void,
|
||||||
softmax_lse_ptr,
|
softmax_lse_ptr as *const core::ffi::c_void,
|
||||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
/* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
|
||||||
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
|
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
|
||||||
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
|
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
|
||||||
/* q_batch_stride */ 0,
|
/* q_batch_stride */ 0,
|
||||||
/* k_batch_stride */ 0,
|
/* k_batch_stride */ 0,
|
||||||
/* v_batch_stride */ 0,
|
/* v_batch_stride */ 0,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -7,5 +7,5 @@ fn main() {
|
|||||||
let builder = bindgen_cuda::Builder::default();
|
let builder = bindgen_cuda::Builder::default();
|
||||||
println!("cargo:info={builder:?}");
|
println!("cargo:info={builder:?}");
|
||||||
let bindings = builder.build_ptx().unwrap();
|
let bindings = builder.build_ptx().unwrap();
|
||||||
bindings.write("src/lib.rs").unwrap();
|
bindings.write("src/ptx.rs").unwrap();
|
||||||
}
|
}
|
||||||
|
@ -1,11 +1,78 @@
|
|||||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
mod ptx;
|
||||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
|
||||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
#[repr(u32)]
|
||||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
pub enum Id {
|
||||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
Affine,
|
||||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
Binary,
|
||||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
Cast,
|
||||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
Conv,
|
||||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
Fill,
|
||||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
Indexing,
|
||||||
|
Quantized,
|
||||||
|
Reduce,
|
||||||
|
Sort,
|
||||||
|
Ternary,
|
||||||
|
Unary,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const ALL_IDS: [Id; 11] = [
|
||||||
|
Id::Affine,
|
||||||
|
Id::Binary,
|
||||||
|
Id::Cast,
|
||||||
|
Id::Conv,
|
||||||
|
Id::Fill,
|
||||||
|
Id::Indexing,
|
||||||
|
Id::Quantized,
|
||||||
|
Id::Reduce,
|
||||||
|
Id::Sort,
|
||||||
|
Id::Ternary,
|
||||||
|
Id::Unary,
|
||||||
|
];
|
||||||
|
|
||||||
|
pub struct Module {
|
||||||
|
index: usize,
|
||||||
|
ptx: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module {
|
||||||
|
pub fn index(&self) -> usize {
|
||||||
|
self.index
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn ptx(&self) -> &'static str {
|
||||||
|
self.ptx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const fn module_index(id: Id) -> usize {
|
||||||
|
let mut i = 0;
|
||||||
|
while i < ALL_IDS.len() {
|
||||||
|
if ALL_IDS[i] as u32 == id as u32 {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
panic!("id not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! mdl {
|
||||||
|
($cst:ident, $id:ident) => {
|
||||||
|
pub const $cst: Module = Module {
|
||||||
|
index: module_index(Id::$id),
|
||||||
|
ptx: ptx::$cst,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
mdl!(AFFINE, Affine);
|
||||||
|
mdl!(BINARY, Binary);
|
||||||
|
mdl!(CAST, Cast);
|
||||||
|
mdl!(CONV, Conv);
|
||||||
|
mdl!(FILL, Fill);
|
||||||
|
mdl!(INDEXING, Indexing);
|
||||||
|
mdl!(QUANTIZED, Quantized);
|
||||||
|
mdl!(REDUCE, Reduce);
|
||||||
|
mdl!(SORT, Sort);
|
||||||
|
mdl!(TERNARY, Ternary);
|
||||||
|
mdl!(UNARY, Unary);
|
||||||
|
11
candle-kernels/src/ptx.rs
Normal file
11
candle-kernels/src/ptx.rs
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||||
|
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||||
|
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||||
|
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||||
|
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||||
|
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||||
|
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||||
|
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||||
|
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||||
|
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||||
|
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
|
@ -7,7 +7,7 @@ use candle::{Result, Tensor};
|
|||||||
/// Arguments
|
/// Arguments
|
||||||
///
|
///
|
||||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories. This is expected to contain log probabilities.
|
/// of categories. This is expected to contain log probabilities.
|
||||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
||||||
///
|
///
|
||||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
|||||||
/// Arguments
|
/// Arguments
|
||||||
///
|
///
|
||||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories. This is expected to raw logits.
|
/// of categories. This is expected to raw logits.
|
||||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
||||||
///
|
///
|
||||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
|||||||
/// Arguments
|
/// Arguments
|
||||||
///
|
///
|
||||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories. This is expected to raw logits.
|
/// of categories. This is expected to raw logits.
|
||||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories.
|
/// of categories.
|
||||||
///
|
///
|
||||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||||
|
@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
|
|||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::backend::BackendStorage;
|
use candle::backend::BackendStorage;
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::SlicePtrOrNull;
|
use candle::cuda_backend::SlicePtrOrNull;
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||||
@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
|
|||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||||
|
|
||||||
let params = (el_count, dims.len(), &ds, src, &out);
|
let mut builder = func.builder();
|
||||||
|
candle::builder_arg!(builder, el_count, dims.len());
|
||||||
|
ds.builder_arg(&mut builder);
|
||||||
|
builder.arg(src);
|
||||||
|
builder.arg(&out);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
block_dim: (1, 32, 1),
|
block_dim: (1, 32, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (&src, &dst, n_cols as i32);
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&src);
|
||||||
|
builder.arg(&dst);
|
||||||
|
candle::builder_arg!(builder, n_cols as i32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
|
|||||||
l2: &Layout,
|
l2: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
|
|||||||
block_dim: (block_size, 1, 1),
|
block_dim: (block_size, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
&alpha,
|
builder.arg(&alpha);
|
||||||
n_cols as i32,
|
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||||
block_size as i32,
|
|
||||||
self.eps,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
|
|||||||
block_dim: (block_size, 1, 1),
|
block_dim: (block_size, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
let func =
|
||||||
|
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
&alpha,
|
builder.arg(&alpha);
|
||||||
&beta,
|
builder.arg(&beta);
|
||||||
n_cols as i32,
|
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||||
block_size as i32,
|
|
||||||
self.eps,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&src);
|
||||||
|
builder.arg(&cos);
|
||||||
|
builder.arg(&sin);
|
||||||
|
builder.arg(&dst);
|
||||||
|
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&cos,
|
builder.arg(&cos);
|
||||||
&sin,
|
builder.arg(&sin);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
(b * h) as u32,
|
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
|
||||||
(t * d) as u32,
|
|
||||||
d as u32,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
|
builder.arg(&src);
|
||||||
);
|
builder.arg(&cos);
|
||||||
|
builder.arg(&sin);
|
||||||
|
builder.arg(&dst);
|
||||||
|
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.8.4" }
|
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
533
candle-transformers/src/models/csm.rs
Normal file
533
candle-transformers/src/models/csm.rs
Normal file
@ -0,0 +1,533 @@
|
|||||||
|
//! Implementation of the Conversational Speech Model (CSM) from Sesame
|
||||||
|
//!
|
||||||
|
//! See: [CSM](Conversational Speech Model)
|
||||||
|
//!
|
||||||
|
/// CSM (Conversational Speech Model) is a speech generation model from Sesame that generates RVQ
|
||||||
|
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
|
||||||
|
/// smaller audio decoder that produces Mimi audio codes.
|
||||||
|
///
|
||||||
|
use crate::generation::LogitsProcessor;
|
||||||
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum Flavor {
|
||||||
|
#[serde(rename = "llama-1B")]
|
||||||
|
Llama1B,
|
||||||
|
#[serde(rename = "llama-100M")]
|
||||||
|
Llama100M,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub audio_num_codebooks: usize,
|
||||||
|
pub audio_vocab_size: usize,
|
||||||
|
pub backbone_flavor: Flavor,
|
||||||
|
pub decoder_flavor: Flavor,
|
||||||
|
pub text_vocab_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LlamaConfig {
|
||||||
|
vocab_size: usize,
|
||||||
|
num_layers: usize,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
embed_dim: usize,
|
||||||
|
max_seq_len: usize,
|
||||||
|
intermediate_dim: usize,
|
||||||
|
norm_eps: f64,
|
||||||
|
rope_base: f32,
|
||||||
|
scale_factor: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlamaConfig {
|
||||||
|
pub fn from_flavor(flavor: Flavor) -> Self {
|
||||||
|
match flavor {
|
||||||
|
Flavor::Llama1B => Self {
|
||||||
|
vocab_size: 128256,
|
||||||
|
num_layers: 16,
|
||||||
|
num_heads: 32,
|
||||||
|
num_kv_heads: 8,
|
||||||
|
embed_dim: 2048,
|
||||||
|
max_seq_len: 2048,
|
||||||
|
intermediate_dim: 8192,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
rope_base: 500_000.,
|
||||||
|
scale_factor: 32,
|
||||||
|
},
|
||||||
|
Flavor::Llama100M => Self {
|
||||||
|
vocab_size: 128256,
|
||||||
|
num_layers: 4,
|
||||||
|
num_heads: 8,
|
||||||
|
num_kv_heads: 2,
|
||||||
|
embed_dim: 1024,
|
||||||
|
max_seq_len: 2048,
|
||||||
|
intermediate_dim: 8192,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
rope_base: 500_000.,
|
||||||
|
scale_factor: 32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
|
||||||
|
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||||
|
(0..head_dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
|
||||||
|
let low_freq_factor = 1.0;
|
||||||
|
let high_freq_factor = 4.0;
|
||||||
|
let original_max_position_embeddings = 8192;
|
||||||
|
let scale_factor = cfg.scale_factor as f32;
|
||||||
|
let theta = {
|
||||||
|
let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
|
||||||
|
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
|
||||||
|
|
||||||
|
calculate_default_inv_freq(cfg)
|
||||||
|
.into_iter()
|
||||||
|
.map(|freq| {
|
||||||
|
let wavelen = 2. * std::f32::consts::PI / freq;
|
||||||
|
if wavelen < high_freq_wavelen {
|
||||||
|
freq
|
||||||
|
} else if wavelen > low_freq_wavelen {
|
||||||
|
freq / scale_factor
|
||||||
|
} else {
|
||||||
|
let smooth = (original_max_position_embeddings as f32 / wavelen
|
||||||
|
- low_freq_factor)
|
||||||
|
/ (high_freq_factor - low_freq_factor);
|
||||||
|
(1. - smooth) * freq / scale_factor + smooth * freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
let theta = Tensor::new(theta, dev)?;
|
||||||
|
let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((cfg.max_seq_len, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
// This is different from the paper, see:
|
||||||
|
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||||
|
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||||
|
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||||
|
Ok(Self { cos, sin })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
|
||||||
|
let weight = vb.get((hidden_size,), "scale")?;
|
||||||
|
Ok(RmsNorm::new(weight, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
num_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||||
|
let kv_dim = cfg.num_kv_heads * head_dim;
|
||||||
|
|
||||||
|
let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache: None,
|
||||||
|
num_heads: cfg.num_heads,
|
||||||
|
num_kv_heads: cfg.num_kv_heads,
|
||||||
|
num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, q_len, _) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
|
||||||
|
let (query_states, key_states) =
|
||||||
|
self.rotary_emb
|
||||||
|
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
|
||||||
|
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
|
||||||
|
|
||||||
|
let attn_output = {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&value_states)?
|
||||||
|
};
|
||||||
|
attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, q_len, self.num_heads * self.head_dim))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Mlp {
|
||||||
|
w1: Linear,
|
||||||
|
w2: Linear,
|
||||||
|
w3: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let w1 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w1"))?;
|
||||||
|
let w2 = linear_b(cfg.intermediate_dim, cfg.embed_dim, false, vb.pp("w2"))?;
|
||||||
|
let w3 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w3"))?;
|
||||||
|
Ok(Self { w1, w2, w3 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Mlp {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = xs.apply(&self.w1)?.silu()?;
|
||||||
|
let rhs = xs.apply(&self.w3)?;
|
||||||
|
(lhs * rhs)?.apply(&self.w2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Layer {
|
||||||
|
mlp_norm: RmsNorm,
|
||||||
|
sa_norm: RmsNorm,
|
||||||
|
attn: Attention,
|
||||||
|
mlp: Mlp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Layer {
|
||||||
|
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let mlp_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("mlp_norm"))?;
|
||||||
|
let sa_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("sa_norm"))?;
|
||||||
|
let attn = Attention::new(cfg, rotary_emb, vb.pp("attn"))?;
|
||||||
|
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
|
||||||
|
Ok(Self {
|
||||||
|
mlp_norm,
|
||||||
|
sa_norm,
|
||||||
|
attn,
|
||||||
|
mlp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.sa_norm.forward(xs)?;
|
||||||
|
let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
|
||||||
|
residual + xs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.attn.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LlamaModel {
|
||||||
|
layers: Vec<Layer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlamaModel {
|
||||||
|
pub fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_layers);
|
||||||
|
let vb_l = vb.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_layers {
|
||||||
|
let layer = Layer::new(cfg, rotary_emb.clone(), vb_l.pp(layer_idx))?;
|
||||||
|
layers.push(layer);
|
||||||
|
}
|
||||||
|
let norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
&self,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (_b_size, seq_len, _embed_dim) = xs.dims3()?;
|
||||||
|
let attention_mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
|
||||||
|
Some(mask)
|
||||||
|
};
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
|
||||||
|
}
|
||||||
|
let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
backbone: LlamaModel,
|
||||||
|
decoder: LlamaModel,
|
||||||
|
codebook0_head: Linear,
|
||||||
|
audio_embeddings: Embedding,
|
||||||
|
text_embeddings: Embedding,
|
||||||
|
projection: Linear,
|
||||||
|
audio_head: Tensor,
|
||||||
|
config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let backbone_cfg = LlamaConfig::from_flavor(cfg.backbone_flavor);
|
||||||
|
let backbone = LlamaModel::new(&backbone_cfg, vb.pp("backbone"))?;
|
||||||
|
let decoder_cfg = LlamaConfig::from_flavor(cfg.decoder_flavor);
|
||||||
|
let decoder = LlamaModel::new(&decoder_cfg, vb.pp("decoder"))?;
|
||||||
|
let backbone_dim = backbone_cfg.embed_dim;
|
||||||
|
let decoder_dim = decoder_cfg.embed_dim;
|
||||||
|
let audio_embeddings = embedding(
|
||||||
|
cfg.audio_vocab_size * cfg.audio_num_codebooks,
|
||||||
|
backbone_dim,
|
||||||
|
vb.pp("audio_embeddings"),
|
||||||
|
)?;
|
||||||
|
let text_embeddings =
|
||||||
|
embedding(cfg.text_vocab_size, backbone_dim, vb.pp("text_embeddings"))?;
|
||||||
|
let projection = linear_b(backbone_dim, decoder_dim, false, vb.pp("projection"))?;
|
||||||
|
let codebook0_head = linear_b(
|
||||||
|
backbone_dim,
|
||||||
|
cfg.audio_vocab_size,
|
||||||
|
false,
|
||||||
|
vb.pp("codebook0_head"),
|
||||||
|
)?;
|
||||||
|
let audio_head = vb.get(
|
||||||
|
(
|
||||||
|
cfg.audio_num_codebooks - 1,
|
||||||
|
decoder_dim,
|
||||||
|
cfg.audio_vocab_size,
|
||||||
|
),
|
||||||
|
"audio_head",
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
backbone,
|
||||||
|
decoder,
|
||||||
|
codebook0_head,
|
||||||
|
audio_embeddings,
|
||||||
|
text_embeddings,
|
||||||
|
projection,
|
||||||
|
audio_head,
|
||||||
|
config: cfg.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.backbone.clear_kv_cache();
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_frame(
|
||||||
|
&mut self,
|
||||||
|
tokens: &Tensor,
|
||||||
|
tokens_mask: &Tensor,
|
||||||
|
input_pos: usize,
|
||||||
|
lp: &mut LogitsProcessor,
|
||||||
|
) -> Result<Vec<u32>> {
|
||||||
|
let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
|
||||||
|
let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
|
||||||
|
let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
|
||||||
|
let text_embeds = self.text_embeddings.forward(&text_tokens)?;
|
||||||
|
let arange = (Tensor::arange(
|
||||||
|
0u32,
|
||||||
|
self.config.audio_num_codebooks as u32,
|
||||||
|
&self.decoder.device,
|
||||||
|
)? * self.config.audio_vocab_size as f64)?;
|
||||||
|
let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
|
||||||
|
let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
|
||||||
|
b_sz,
|
||||||
|
seq_len,
|
||||||
|
self.config.audio_num_codebooks,
|
||||||
|
(),
|
||||||
|
))?;
|
||||||
|
let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
|
||||||
|
let embeds = embeds.broadcast_mul(
|
||||||
|
&tokens_mask
|
||||||
|
.to_dtype(self.backbone.dtype)?
|
||||||
|
.unsqueeze(D::Minus1)?,
|
||||||
|
)?;
|
||||||
|
let embeds = embeds.sum(2)?;
|
||||||
|
let h = self.backbone.forward(&embeds, input_pos)?;
|
||||||
|
let c0_logits = h.apply(&self.codebook0_head)?;
|
||||||
|
let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
|
||||||
|
let mut all_samples = vec![c0_sample];
|
||||||
|
let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
|
||||||
|
let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
|
||||||
|
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
|
||||||
|
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
let mut decoder_pos = 0;
|
||||||
|
for i in 1..self.config.audio_num_codebooks {
|
||||||
|
let proj_h = curr_h.apply(&self.projection)?;
|
||||||
|
let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
|
||||||
|
decoder_pos += curr_h.dim(1)?;
|
||||||
|
let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
|
||||||
|
let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
|
||||||
|
all_samples.push(ci_sample);
|
||||||
|
let ci_sample = Tensor::from_slice(
|
||||||
|
&[ci_sample + (i * self.config.audio_vocab_size) as u32],
|
||||||
|
(1, 1),
|
||||||
|
&self.decoder.device,
|
||||||
|
)?;
|
||||||
|
let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
|
||||||
|
curr_h = ci_embed
|
||||||
|
}
|
||||||
|
Ok(all_samples)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut mask = vec![1u8; cb];
|
||||||
|
mask.push(0);
|
||||||
|
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
|
||||||
|
|
||||||
|
frame.push(0);
|
||||||
|
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut tokens = vec![];
|
||||||
|
let mut mask = vec![];
|
||||||
|
for &v in ids.iter() {
|
||||||
|
let mut token = vec![0; cb];
|
||||||
|
token.push(v);
|
||||||
|
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
|
||||||
|
tokens.push(token);
|
||||||
|
let mut m = vec![0u8; cb];
|
||||||
|
m.push(1);
|
||||||
|
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
|
||||||
|
mask.push(m);
|
||||||
|
}
|
||||||
|
let tokens = Tensor::cat(&tokens, 1)?;
|
||||||
|
let mask = Tensor::cat(&mask, 1)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
|
}
|
@ -104,7 +104,7 @@ impl EncoderBlock {
|
|||||||
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
||||||
let cfg1 = Conv1dConfig {
|
let cfg1 = Conv1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||||
@ -196,7 +196,7 @@ impl DecoderBlock {
|
|||||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||||
let cfg = ConvTranspose1dConfig {
|
let cfg = ConvTranspose1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
||||||
|
@ -6,8 +6,8 @@ pub fn get_noise(
|
|||||||
width: usize,
|
width: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let height = (height + 15) / 16 * 2;
|
let height = height.div_ceil(16) * 2;
|
||||||
let width = (width + 15) / 16 * 2;
|
let width = width.div_ceil(16) * 2;
|
||||||
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
|
|||||||
|
|
||||||
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||||
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
||||||
let height = (height + 15) / 16;
|
let height = height.div_ceil(16);
|
||||||
let width = (width + 15) / 16;
|
let width = width.div_ceil(16);
|
||||||
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
||||||
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
||||||
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
||||||
|
@ -27,7 +27,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_inner(&self) -> usize {
|
fn d_inner(&self) -> usize {
|
||||||
|
@ -81,6 +81,126 @@ impl Config {
|
|||||||
vocab_size: 59514,
|
vocab_size: 59514,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn opus_mt_en_zh() -> Self {
|
||||||
|
Self {
|
||||||
|
activation_function: candle_nn::Activation::Swish,
|
||||||
|
d_model: 512,
|
||||||
|
decoder_attention_heads: 8,
|
||||||
|
decoder_ffn_dim: 2048,
|
||||||
|
decoder_layers: 6,
|
||||||
|
decoder_start_token_id: 65000,
|
||||||
|
decoder_vocab_size: Some(65001),
|
||||||
|
encoder_attention_heads: 8,
|
||||||
|
encoder_ffn_dim: 2048,
|
||||||
|
encoder_layers: 6,
|
||||||
|
eos_token_id: 0,
|
||||||
|
forced_eos_token_id: 0,
|
||||||
|
is_encoder_decoder: true,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
pad_token_id: 65000,
|
||||||
|
scale_embedding: true,
|
||||||
|
share_encoder_decoder_embeddings: true,
|
||||||
|
use_cache: true,
|
||||||
|
vocab_size: 65001,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn opus_mt_en_hi() -> Self {
|
||||||
|
Self {
|
||||||
|
activation_function: candle_nn::Activation::Swish,
|
||||||
|
d_model: 512,
|
||||||
|
decoder_attention_heads: 8,
|
||||||
|
decoder_ffn_dim: 2048,
|
||||||
|
decoder_layers: 6,
|
||||||
|
decoder_start_token_id: 61949,
|
||||||
|
decoder_vocab_size: Some(61950),
|
||||||
|
encoder_attention_heads: 8,
|
||||||
|
encoder_ffn_dim: 2048,
|
||||||
|
encoder_layers: 6,
|
||||||
|
eos_token_id: 0,
|
||||||
|
forced_eos_token_id: 0,
|
||||||
|
is_encoder_decoder: true,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
pad_token_id: 61949,
|
||||||
|
scale_embedding: true,
|
||||||
|
share_encoder_decoder_embeddings: true,
|
||||||
|
use_cache: true,
|
||||||
|
vocab_size: 61950,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn opus_mt_en_es() -> Self {
|
||||||
|
Self {
|
||||||
|
activation_function: candle_nn::Activation::Swish,
|
||||||
|
d_model: 512,
|
||||||
|
decoder_attention_heads: 8,
|
||||||
|
decoder_ffn_dim: 2048,
|
||||||
|
decoder_layers: 6,
|
||||||
|
decoder_start_token_id: 65000,
|
||||||
|
decoder_vocab_size: Some(65001),
|
||||||
|
encoder_attention_heads: 8,
|
||||||
|
encoder_ffn_dim: 2048,
|
||||||
|
encoder_layers: 6,
|
||||||
|
eos_token_id: 0,
|
||||||
|
forced_eos_token_id: 0,
|
||||||
|
is_encoder_decoder: true,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
pad_token_id: 65000,
|
||||||
|
scale_embedding: true,
|
||||||
|
share_encoder_decoder_embeddings: true,
|
||||||
|
use_cache: true,
|
||||||
|
vocab_size: 65001,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn opus_mt_en_fr() -> Self {
|
||||||
|
Self {
|
||||||
|
activation_function: candle_nn::Activation::Swish,
|
||||||
|
d_model: 512,
|
||||||
|
decoder_attention_heads: 8,
|
||||||
|
decoder_ffn_dim: 2048,
|
||||||
|
decoder_layers: 6,
|
||||||
|
decoder_start_token_id: 59513,
|
||||||
|
decoder_vocab_size: Some(59514),
|
||||||
|
encoder_attention_heads: 8,
|
||||||
|
encoder_ffn_dim: 2048,
|
||||||
|
encoder_layers: 6,
|
||||||
|
eos_token_id: 0,
|
||||||
|
forced_eos_token_id: 0,
|
||||||
|
is_encoder_decoder: true,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
pad_token_id: 59513,
|
||||||
|
scale_embedding: true,
|
||||||
|
share_encoder_decoder_embeddings: true,
|
||||||
|
use_cache: true,
|
||||||
|
vocab_size: 59514,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn opus_mt_en_ru() -> Self {
|
||||||
|
Self {
|
||||||
|
activation_function: candle_nn::Activation::Swish,
|
||||||
|
d_model: 512,
|
||||||
|
decoder_attention_heads: 8,
|
||||||
|
decoder_ffn_dim: 2048,
|
||||||
|
decoder_layers: 6,
|
||||||
|
decoder_start_token_id: 62517,
|
||||||
|
decoder_vocab_size: Some(62518),
|
||||||
|
encoder_attention_heads: 8,
|
||||||
|
encoder_ffn_dim: 2048,
|
||||||
|
encoder_layers: 6,
|
||||||
|
eos_token_id: 0,
|
||||||
|
forced_eos_token_id: 0,
|
||||||
|
is_encoder_decoder: true,
|
||||||
|
max_position_embeddings: 512,
|
||||||
|
pad_token_id: 62517,
|
||||||
|
scale_embedding: true,
|
||||||
|
share_encoder_decoder_embeddings: true,
|
||||||
|
use_cache: true,
|
||||||
|
vocab_size: 62518,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -716,7 +716,7 @@ pub mod transformer {
|
|||||||
None => {
|
None => {
|
||||||
let hidden_dim = self.dim * 4;
|
let hidden_dim = self.dim * 4;
|
||||||
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
||||||
(n_hidden + 255) / 256 * 256
|
n_hidden.div_ceil(256) * 256
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -27,6 +27,7 @@ pub mod codegeex4_9b;
|
|||||||
pub mod colpali;
|
pub mod colpali;
|
||||||
pub mod convmixer;
|
pub mod convmixer;
|
||||||
pub mod convnext;
|
pub mod convnext;
|
||||||
|
pub mod csm;
|
||||||
pub mod dac;
|
pub mod dac;
|
||||||
pub mod debertav2;
|
pub mod debertav2;
|
||||||
pub mod deepseek2;
|
pub mod deepseek2;
|
||||||
|
@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_<T: Float>(
|
|||||||
let samples = {
|
let samples = {
|
||||||
let mut samples_padded = samples.to_vec();
|
let mut samples_padded = samples.to_vec();
|
||||||
let to_add = n_len * fft_step - samples.len();
|
let to_add = n_len * fft_step - samples.len();
|
||||||
samples_padded.extend(std::iter::repeat(zero).take(to_add));
|
samples_padded.extend(std::iter::repeat_n(zero, to_add));
|
||||||
samples_padded
|
samples_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
|||||||
let samples = {
|
let samples = {
|
||||||
let mut samples_padded = samples.to_vec();
|
let mut samples_padded = samples.to_vec();
|
||||||
let to_add = n_len * fft_step - samples.len();
|
let to_add = n_len * fft_step - samples.len();
|
||||||
samples_padded.extend(std::iter::repeat(zero).take(to_add));
|
samples_padded.extend(std::iter::repeat_n(zero, to_add));
|
||||||
samples_padded
|
samples_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user