Compare commits

..

18 Commits

Author SHA1 Message Date
cf9d7bf24c Add the CSM model. (#2862)
* Add the CSM model.

* Add some code to load the model.

* Load the text tokenizer.

* Add frame generation.

* Get the sampling to work.

* Rope fix.

* Autoregressive generation.

* Generate some audio file.

* Use the actual prompt.

* Support multiple turns.

* Add a very barebone readme.

* Move some of the shared bits to the model.
2025-04-04 06:48:03 +02:00
9d31361c4f Fix for clippy 1.86. (#2864)
* Fix for clippy 1.86.

* More clippy fixes.

* More fixes.
2025-04-03 19:38:27 +02:00
648596c073 Added readmes to examples (#2835)
* added chatGLM readme

* changed wording in readme

* added readme for chinese-clip

* added readme for convmixer

* added readme for custom ops

* added readme for efficientnet

* added readme for llama

* added readme to mnist-training

* added readme to musicgen

* added readme to quantized-phi

* added readme to starcoder2

* added readme to whisper-microphone

* added readme to yi

* added readme to yolo-v3

* added readme to whisper-microphone

* added space to example in glm4 readme

* fixed mamba example readme to run mamba instead of mamba-minimal

* removed slash escape character

* changed moondream image to yolo-v8 example image

* added procedure for making the reinforcement-learning example work with a virtual environment on my machine

* added simple one line summaries to the example readmes without

* changed non-existant image to yolo example's bike.jpg

* added backslash to sam command

* removed trailing - from siglip

* added SoX to silero-vad example readme

* replaced procedure for uv on mac with warning that uv isn't currently compatible with pyo3

* added example to falcon readme

* added --which arg to stella-en-v5 readme

* fixed image path in vgg readme

* fixed the image path in the vit readme

* Update README.md

* Update README.md

* Update README.md

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2025-04-03 09:18:29 +02:00
d9904a3baf Update to cudarc 0.14 (breaking change). (#2858)
* Start updating to cudarc 0.14.

* Adapt a couple more things.

* And a couple more fixes.

* More tweaks.

* And a couple more fixes.

* Bump the major version number.

* Proper module system for the cuda kernels.

* Proper ptx loading.

* Launch the sort kernel.

* Custom op.

* Start using the builder pattern.

* More builder.

* More builder.

* Get candle-core to compile.

* Get the tests to pass.

* Get candle-nn to work too.

* Support for custom cuda functions.

* cudnn fixes.

* Get flash attn to run.

* Switch the crate versions to be alpha.

* Bump the ug dependency.
2025-04-03 09:12:19 +02:00
d6db305829 Added new language pairs to marian-mt example. (#2860)
* added new language pairs to marian-mt

* lint

* seperated python code for converting tokenizers into its own file and and added a reqirements.txt for dependencies, updated instructions in readme and included python version

* Cleanup.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-04-02 23:50:14 +02:00
b4daa03e59 add as_cuda_slice_mut to CudaStorage and CudaDType (#2859) 2025-04-01 19:34:52 +02:00
9541467d6b Add flip to tensor (#2855)
* Add `flip` to `tensor`

* Move the tests to the proper places.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-04-01 09:07:16 +02:00
6429609090 Added Deepseekr1 Llama8b variant to quantized example (#2842)
* added deepseekr1 llama8b variant to quantized example

* lint
2025-03-30 10:55:21 +02:00
ba473290da Added DeepseekR1 Qwen7B variant to quantized-qwen2-instruct example (#2843)
* quantized deepseek qwen generating tokens

* removed is_deepseek from Args and replaced prompt if statement with pattern matching
2025-03-30 10:54:22 +02:00
59c26195db Fix CIFAR10 dataset types and dimension ordering (#2845) 2025-03-30 10:53:25 +02:00
cb02b389d5 Fix reinforcement learning example (#2837) 2025-03-26 16:27:45 +01:00
0d4097031c fixed rand import for mnist-training (#2833) 2025-03-26 08:10:03 +01:00
10853b803c fixed rand imports for whisper-microphone example (#2834) 2025-03-26 08:09:27 +01:00
f3d472952f fix: candle-flash-attn linux and msvc build (#2829)
* fix: candle-flash-attn linux and msvc build

* Missing newline at eof.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2025-03-25 08:45:12 +01:00
67b85f79f1 Pickle decoder fix and Long1 opcode addition. (#2824)
* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2025-03-23 08:10:08 +01:00
0b24f7f0a4 Fix for whisper example. rand::distribution is now rand::distr (#2811) 2025-03-16 19:14:55 +01:00
3afb04925a Allow for growing the default KV cache when needed. (#2810) 2025-03-16 17:30:25 +01:00
cbf5fc80c2 Add Gemma 3 1b IT toe Gemma examples (#2809)
- Updates the Gemma example to include Gemma 3 1b instruction tuned.
2025-03-16 17:00:48 +01:00
83 changed files with 2610 additions and 2107 deletions

View File

@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.8.4"
version = "0.9.0-alpha.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
candle-nn = { path = "./candle-nn", version = "0.8.4" }
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.1.0"
ug-cuda = "0.1.0"
ug-metal = "0.1.0"
ug = "0.2.0"
ug-cuda = "0.2.0"
ug-metal = "0.2.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}

View File

@ -21,7 +21,9 @@ impl BenchDevice for Device {
Device::Cpu => Ok(()),
Device::Cuda(device) => {
#[cfg(feature = "cuda")]
return Ok(device.synchronize()?);
return Ok(device
.synchronize()
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
#[cfg(not(feature = "cuda"))]
panic!("Cuda device without cuda feature enabled: {:?}", device)
}

View File

@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_device());
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv2d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv2d.launch::<CudaSlice<u8>, _, _, _>(
alg,

View File

@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
pub use candle_kernels as kernels;
pub use cudarc;
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
use half::{bf16, f16};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@ -24,10 +25,17 @@ impl DeviceId {
struct CudaRng(cudarc::curand::CudaRng);
unsafe impl Send for CudaRng {}
pub struct ModuleStore {
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
}
#[derive(Clone)]
pub struct CudaDevice {
id: DeviceId,
device: Arc<cudarc::driver::CudaDevice>,
context: Arc<cudarc::driver::CudaContext>,
modules: Arc<std::sync::RwLock<ModuleStore>>,
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
stream: Arc<cudarc::driver::CudaStream>,
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
curand: Arc<Mutex<CudaRng>>,
}
@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaDevice>;
type Target = Arc<cudarc::driver::CudaStream>;
fn deref(&self) -> &Self::Target {
&self.device
&self.stream
}
}
pub struct CudaFunc {
func: CudaFunction,
stream: Arc<cudarc::driver::CudaStream>,
}
impl std::ops::Deref for CudaFunc {
type Target = CudaFunction;
fn deref(&self) -> &Self::Target {
&self.func
}
}
impl CudaFunc {
pub fn into_cuda_function(self) -> CudaFunction {
self.func
}
}
#[macro_export]
macro_rules! builder_arg {
($b:ident, $($arg:expr),*) => {
$(
let __arg = $arg;
$b.arg(&__arg);
)*
};
}
impl CudaFunc {
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
self.stream.launch_builder(&self.func)
}
}
impl CudaDevice {
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
self.device.clone()
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
self.stream.clone()
}
#[cfg(not(target_arch = "wasm32"))]
@ -56,7 +99,7 @@ impl CudaDevice {
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
) -> Result<CudaFunc> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
@ -65,12 +108,12 @@ impl CudaDevice {
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
let module = self.context.load_module(ptx).w()?;
let func = module.load_function(func_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn id(&self) -> DeviceId {
@ -84,57 +127,84 @@ impl CudaDevice {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
let params = (&data, v as u8, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
let params = (&data, v as u32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
let params = (&data, v as i64, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
let params = (&data, bf16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
let params = (&data, f16::from_f64(v), elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
let params = (&data, v as f32, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
let params = (&data, v, elem_count);
unsafe { func.launch(cfg, params) }.w()?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
CudaStorageSlice::F64(data)
}
};
@ -144,38 +214,69 @@ impl CudaDevice {
})
}
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
if !self.has_func(module_name, module_name) {
// Leaking the string here is a bit sad but we need a &'static str and this is only
// done once per kernel name.
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
self.load_ptx(ptx.into(), module_name, &[static_module_name])
.map_err(|cuda| CudaError::Load {
cuda,
module_name: module_name.to_string(),
})
.w()?;
pub fn get_or_load_custom_func(
&self,
fn_name: &str,
module_name: &str,
ptx: &str,
) -> Result<CudaFunc> {
let ms = self.custom_modules.read().unwrap();
if let Some(mdl) = ms.get(module_name).as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
self.get_func(module_name, module_name)
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
// able to only build the error value if needed.
.ok_or(CudaError::MissingKernel {
module_name: module_name.to_string(),
})
.w()
drop(ms);
let mut ms = self.custom_modules.write().unwrap();
let cuda_module = self.context.load_module(ptx.into()).w()?;
ms.insert(module_name.to_string(), cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
let ms = self.modules.read().unwrap();
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
let func = mdl.load_function(fn_name).w()?;
return Ok(CudaFunc {
func,
stream: self.stream.clone(),
});
}
drop(ms);
let mut ms = self.modules.write().unwrap();
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
ms.mdls[mdl.index()] = Some(cuda_module.clone());
let func = cuda_module.load_function(fn_name).w()?;
Ok(CudaFunc {
func,
stream: self.stream.clone(),
})
}
}
impl CudaDevice {
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.new_stream().w()?;
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
device,
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
}
@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
type Storage = CudaStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
let stream = context.default_stream();
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
let module_store = ModuleStore {
mdls: [const { None }; kernels::ALL_IDS.len()],
};
Ok(Self {
id: DeviceId::new(),
device,
context,
stream,
blas: Arc::new(blas),
curand: Arc::new(Mutex::new(CudaRng(curand))),
modules: Arc::new(std::sync::RwLock::new(module_store)),
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
})
}
@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
Ok(())
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Cuda {
gpu_id: self.device.ordinal(),
gpu_id: self.context.ordinal(),
}
}
@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_sync_copy(storage).w()?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.htod_copy(storage).w()?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
}
fn synchronize(&self) -> Result<()> {
self.device.synchronize().map_err(crate::Error::wrap)?;
self.stream.synchronize().map_err(crate::Error::wrap)?;
Ok(())
}
}

File diff suppressed because it is too large Load Diff

View File

@ -396,7 +396,10 @@ impl UgIOp1 {
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
Ok(Self {
name,
func: func.into_cuda_function(),
})
}
#[cfg(feature = "metal")]
{
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;
use cudarc::driver::PushKernelArg;
let elem_count = layout.shape().elem_count();
let stream = sto.device.cuda_stream();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
let mut builder = stream.launch_builder(&self.func);
builder.arg(&sto);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
}

View File

@ -45,6 +45,7 @@ pub enum OpCode {
BinFloat = b'G',
Append = b'a',
Appends = b'e',
Long1 = 0x8a,
}
// Avoid using FromPrimitive so as not to drag another dependency.
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
0x8a => Ok(Self::Long1),
value => Err(value),
}
}
@ -106,6 +108,7 @@ pub enum Object {
class_name: String,
},
Int(i32),
Long(i64),
Float(f64),
Unicode(String),
Bool(bool),
@ -170,6 +173,14 @@ impl Object {
}
}
pub fn int_or_long(self) -> OResult<i64> {
match self {
Self::Int(t) => Ok(t as i64),
Self::Long(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
@ -590,6 +601,15 @@ impl Stack {
let obj = self.new_obj(class, args)?;
self.push(obj)
}
OpCode::Long1 => {
let n_bytes = r.read_u8()?;
let mut v = 0;
// Decode the next n bytes in little endian
for i in 0..n_bytes {
v |= (r.read_u8()? as i64) << (i * 8);
}
self.push(Object::Long(v))
}
}
Ok(false)
}
@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let offset = args.remove(1).int_or_long()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let storage_size = storage.remove(4).int_or_long()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
let layout = Layout::new(
crate::Shape::from(size),
stride,
offset * dtype.size_in_bytes(),
);
Ok((layout, dtype, path, storage_size))
}
@ -792,7 +816,7 @@ impl PthTensors {
/// # Arguments
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,

View File

@ -1,10 +1,10 @@
use super::{GgmlDType, QStorage};
use crate::quantized::k_quants::GgmlType;
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
use crate::{CudaDevice, CudaStorage, Result};
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
use half::f16;
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
#[derive(Clone, Debug)]
struct PaddedCudaSlice {
@ -50,19 +50,20 @@ fn quantize_q8_1(
ky: usize,
dev: &CudaDevice,
) -> Result<()> {
use cudarc::driver::LaunchAsync;
let kx = elem_count;
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_blocks as u32, ky as u32, 1),
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
shared_mem_bytes: 0,
};
let params = (src, dst, kx as i32, kx_padded as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(src);
builder.arg(dst);
barg!(builder, kx as i32, kx_padded as i32);
unsafe { builder.launch(cfg) }.w()?;
Ok(())
}
@ -72,8 +73,6 @@ fn dequantize_f32(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
@ -99,7 +98,7 @@ fn dequantize_f32(
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@ -110,15 +109,20 @@ fn dequantize_f32(
};
if is_k {
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -129,8 +133,6 @@ fn dequantize_f16(
elem_count: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let nb = (elem_count + 255) / 256;
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
@ -156,7 +158,7 @@ fn dequantize_f16(
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@ -167,15 +169,20 @@ fn dequantize_f16(
};
if is_k {
let params = (&data.inner, &dst);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
unsafe { builder.launch(cfg) }.w()?;
} else {
let nb32 = match dtype {
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
_ => elem_count / 32,
};
let params = (&data.inner, &dst, nb32 as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&dst);
barg!(builder, nb32 as i32);
unsafe { builder.launch(cfg) }.w()?;
}
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
nrows: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@ -210,7 +215,7 @@ fn dequantize_mul_mat_vec(
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
shared_mem_bytes: 0,
};
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(y);
builder.arg(&dst);
barg!(builder, ncols as i32, nrows as i32);
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
b_size: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < ncols * nrows {
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@ -266,7 +273,7 @@ fn mul_mat_vec_via_q8_1(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
shared_mem_bytes: 0,
};
let params = (
&data.inner,
&y_q8_1,
&dst,
let mut builder = func.builder();
builder.arg(&data.inner);
builder.arg(&y_q8_1);
builder.arg(&dst);
barg!(
builder,
/* ncols_x */ ncols as i32,
/* nrows_x */ nrows as i32,
/* nrows_y */ ncols_padded as i32,
/* nrows_dst */ nrows as i32,
/* nrows_dst */ nrows as i32
);
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
y_cols: usize,
dev: &CudaDevice,
) -> Result<CudaStorage> {
use cudarc::driver::LaunchAsync;
let data_elems = data.len / dtype.type_size() * dtype.block_size();
if data_elems < x_rows * x_cols {
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
@ -338,7 +345,7 @@ fn mul_mat_via_q8_1(
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
shared_mem_bytes: 0,
};
let params = (
/* vx */ &data.inner,
/* vy */ &y_q8_1,
/* dst */ &dst,
let mut builder = func.builder();
builder.arg(/* vx */ &data.inner);
builder.arg(/* vy */ &y_q8_1);
builder.arg(/* dst */ &dst);
barg!(
builder,
/* ncols_x */ x_cols as i32,
/* nrows_x */ x_rows as i32,
/* ncols_y */ y_cols as i32,
/* nrows_y */ k_padded as i32,
/* nrows_dst */ x_rows as i32,
/* nrows_dst */ x_rows as i32
);
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
}
@ -416,7 +425,7 @@ impl QCudaStorage {
let buffer = self
.device
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
.w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
@ -449,7 +458,7 @@ impl QCudaStorage {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.dtoh_sync_copy(data).w()?
self.device.memcpy_dtov(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
@ -462,7 +471,7 @@ impl QCudaStorage {
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
@ -599,7 +608,7 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
.w()?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
@ -624,7 +633,7 @@ mod test {
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@ -634,7 +643,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
@ -647,7 +656,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
@ -662,7 +671,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
@ -673,7 +682,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -687,7 +696,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@ -714,7 +723,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.htod_sync_copy(&vs).w()?;
let y = dev.memcpy_stod(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@ -728,7 +737,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
Ok(())
}
}

View File

@ -56,7 +56,7 @@ impl ArgSort {
mod cuda {
use super::*;
use crate::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
};
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
use crate::{CudaDevice, WithDType};
@ -69,6 +69,8 @@ mod cuda {
layout: &crate::Layout,
_wrap: W,
) -> Result<S> {
use cudarc::driver::PushKernelArg;
let slice = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
@ -76,20 +78,24 @@ mod cuda {
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
} else {
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
};
let ncols = self.last_dim;
let nrows = elem_count / ncols;
let ncols_pad = next_power_of_2(ncols);
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
let cfg = LaunchConfig {
grid_dim: (1, nrows as u32, 1),
block_dim: (ncols_pad as u32, 1, 1),
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
};
unsafe { func.launch(cfg, params) }.w()?;
let stream = dev.cuda_stream();
let mut builder = stream.launch_builder(&func);
let ncols = ncols as i32;
let ncols_pad = ncols_pad as i32;
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
unsafe { builder.launch(cfg) }.w()?;
Ok(S::U32(dst))
}
}

View File

@ -2580,6 +2580,28 @@ impl Tensor {
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
/// This function makes a copy of the tensors data.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
/// let t_flipped = t.flip(&[0])?;
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
let mut result = self.clone();
for &dim in dims.iter() {
let size = result.dim(dim)?;
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
result = result.index_select(&indices_tensor, dim)?;
}
Ok(result)
}
}
macro_rules! bin_trait {

View File

@ -24,6 +24,15 @@ macro_rules! test_device {
};
}
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
assert_eq!(t1.shape(), t2.shape());
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
let all_equal = eq_tensor.sum_all()?;
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
Ok(())
}
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
let b = 10f32.powi(digits);
let t = t.to_vec0::<f32>()?;

View File

@ -1,6 +1,6 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
fn simple_grad(device: &Device) -> Result<()> {
let x = Var::new(&[3f32, 1., 4.], device)?;
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
Ok(())
}
#[test]
fn test_flip_backprop() -> Result<()> {
let device = &Device::Cpu;
// Create a tensor (leaf node) that requires gradients
let x = Var::ones((2, 2), DType::F64, device)?;
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
let y = x.matmul(&weights)?;
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
let z = y.flip(&[1])?;
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
let loss = z.sum_all()?;
let grad_store = loss.backward()?;
let grad_x = grad_store.get_id(x.id()).unwrap();
let flipped_weights = weights.flip(&[1])?;
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
Ok(())
}
test_device!(
simple_grad,
simple_grad_cpu,

View File

@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
);
Ok(())
}
#[test]
fn test_flip_1d() -> Result<()> {
// 1D: [0, 1, 2, 3, 4]
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
let flipped = t.flip(&[0])?;
// Expected: [4, 3, 2, 1, 0]
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_2d() -> Result<()> {
// 2D:
// [[0, 1, 2],
// [3, 4, 5]]
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
let flipped = t.flip(&[0, 1])?;
// Expected:
// [[5, 4, 3],
// [2, 1, 0]]
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}
#[test]
fn test_flip_3d_channels() -> Result<()> {
// 3D:
// [[[0,1,2],
// [3,4,5]],
//
// [[6,7,8],
// [9,10,11]]]
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
let flipped = t.flip(&[2])?;
// Expected:
// [[[2,1,0],
// [5,4,3]],
//
// [[8,7,6],
// [11,10,9]]]
let expected = Tensor::from_vec(
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
(2, 2, 3),
&Device::Cpu,
)?;
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
Ok(())
}

View File

@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
// image-rs crate convention is to load in (width, height, channels) order
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
let image = image::load_from_memory(value.data()).unwrap();
buffer_images.extend(image.to_rgb8().as_raw());
}
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
}
}
}
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
.to_dtype(DType::U8)?
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
.to_dtype(DType::F32)?
.permute((0, 3, 2, 1))?
/ 255.)?;
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
Ok((images, labels))

View File

@ -0,0 +1,13 @@
# candle-chatglm
Uses `THUDM/chatglm3-6b` to generate chinese text. Will not generate text for english (usually).
## Text Generation
```bash
cargo run --example chatglm --release -- --prompt "部署门槛较低等众多优秀特 "
> 部署门槛较低等众多优秀特 点使得其成为了一款备受欢迎的AI助手。
>
> 作为一款人工智能助手ChatGLM3-6B
```

View File

@ -0,0 +1,42 @@
# candle-chinese-clip
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts. This one is trained using in chinese instead of english.
## Running on cpu
```bash
$ cargo run --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```
## Running on metal
```bash
$ cargo run --features metal --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```

View File

@ -0,0 +1,17 @@
# candle-convmixer
A lightweight CNN architecture that processes image patches similar to a vision transformer, with separate spatial and channel convolutions.
ConvMixer from [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792) and [ConvMixer](https://github.com/locuslab/convmixer).
## Running an example
```bash
$ cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
> mountain bike, all-terrain bike, off-roader: 61.75%
> unicycle, monocycle : 5.73%
> moped : 3.66%
> bicycle-built-for-two, tandem bicycle, tandem: 3.51%
> crash helmet : 0.85%
```

View File

@ -0,0 +1,14 @@
# Conversational Speech Model (CSM)
CSM is a speech generation model from Sesame,
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
It can generate a conversational speech between two different speakers.
The speakers turn are delimited by the `|` character in the prompt.
```bash
cargo run --example csm --features cuda -r -- \
--voices voices.safetensors \
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
```

View File

@ -0,0 +1,243 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::csm::{Config, Model};
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "1b")]
Csm1b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
/// The prompt to be used for the generation, use a | to separate the speakers.
#[arg(long, default_value = "Hey how are you doing today?")]
prompt: String,
/// The voices to be used, in safetensors format.
#[arg(long)]
voices: String,
/// The output file using the wav format.
#[arg(long, default_value = "out.wav")]
out_file: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.7)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
/// The model size to use.
#[arg(long, default_value = "1b")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
config: Option<String>,
#[arg(long)]
weights: Option<String>,
/// The mimi model weight file, in safetensor format.
#[arg(long)]
mimi_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature, args.repeat_penalty, args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let name = match args.which {
Which::Csm1b => "sesame/csm-1b",
};
name.to_string()
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let filenames = match args.weights {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => vec![repo.get("model.safetensors")?],
};
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("meta-llama/Llama-3.2-1B".to_string())
.get("tokenizer.json")?,
};
let mimi_filename = match args.mimi_weights {
Some(model) => std::path::PathBuf::from(model),
None => Api::new()?
.model("kyutai/mimi".to_string())
.get("model.safetensors")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = match args.config {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let device = candle_examples::device(args.cpu)?;
let (mut model, device) = {
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
(model, device)
};
let mut mimi_model = {
use candle_transformers::models::mimi;
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
let config = mimi::Config::v0_1(Some(32));
mimi::Model::new(config, vb)?
};
let cb = config.audio_num_codebooks;
println!("loaded the model in {:?}", start.elapsed());
let voices = candle::safetensors::load(args.voices, &device)?;
let mut lp = candle_transformers::generation::LogitsProcessor::new(
args.seed,
Some(args.temperature),
None,
);
let tokens = voices
.get("tokens")
.expect("no tokens in prompt")
.to_dtype(DType::U32)?;
let mask = voices.get("mask").expect("no mask in prompt").clone();
let mut pos = 0;
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let mut all_pcms = vec![];
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
println!("{prompt:?}");
let speaker_idx = turn_idx % 2;
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
let mut generated_tokens = vec![];
loop {
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
let is_done = frame.iter().all(|&x| x == 0);
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
print!("\rframe {pos}");
if is_done {
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
pos += tokens.dim(1)?;
break;
}
generated_tokens.push(tokens.clone());
}
println!();
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
let pcm = mimi_model.decode(&generated_tokens)?;
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
all_pcms.push(pcm);
}
let pcm = Tensor::cat(&all_pcms, 0)?;
let pcm = pcm.to_vec1::<f32>()?;
println!("writing output file {}", args.out_file);
let mut output = std::fs::File::create(args.out_file)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
Ok(())
}

View File

@ -0,0 +1,17 @@
# candle-custom-ops
This example illustrates how to implement forward and backward passes for custom operations on the CPU and GPU.
The custom op in this example implements RMS normalization for the CPU and CUDA.
## Running an example
```bash
$ cargo run --example custom-ops
> [[ 0., 1., 2., 3., 4., 5., 6.],
> [ 7., 8., 9., 10., 11., 12., 13.]]
> Tensor[[2, 7], f32]
> [[0.0000, 0.2773, 0.5547, 0.8320, 1.1094, 1.3867, 1.6641],
> [0.6864, 0.7845, 0.8825, 0.9806, 1.0786, 1.1767, 1.2748]]
> Tensor[[2, 7], f32]
```

View File

@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
use candle::cuda_backend::WrapErr;
let (d1, d2) = layout.shape().dims2()?;
let d1 = d1 as u32;
@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
let params = (&dst, &slice, self.eps, d1, d2);
let func =
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let cfg = LaunchConfig {
grid_dim: (d1, 1, 1),
block_dim: (d2, 1, 1),
shared_mem_bytes: 0,
};
unsafe { func.launch(cfg, params) }.w()?;
let mut builder = func.builder();
builder.arg(&dst);
builder.arg(&slice);
candle::builder_arg!(builder, self.eps, d1, d2);
unsafe { builder.launch(cfg) }.w()?;
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
Ok((dst, layout.shape().clone()))

View File

@ -0,0 +1,15 @@
# candle-efficientnet
Demonstrates a Candle implementation of EfficientNet for image classification based on ImageNet classes.
## Running an example
```bash
$ cargo run --example efficientnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which b1
> bicycle-built-for-two, tandem bicycle, tandem: 45.85%
> mountain bike, all-terrain bike, off-roader: 30.45%
> crash helmet : 2.58%
> unicycle, monocycle : 2.21%
> tricycle, trike, velocipede: 1.53%
```

View File

@ -1,3 +1,10 @@
# candle-falcon
Falcon is a general large language model.
## Running an example
Make sure to include the `--use-f32` flag if using CPU, because there isn't a BFloat16 implementation yet.
```
cargo run --example falcon --release -- --prompt "Flying monkeys are" --use-f32
```

View File

@ -50,6 +50,8 @@ enum Which {
InstructV2_9B,
#[value(name = "3-1b")]
BaseV3_1B,
#[value(name = "3-1b-it")]
InstructV3_1B,
}
enum Model {
@ -272,6 +274,7 @@ fn main() -> Result<()> {
Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
@ -292,13 +295,10 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
if args.which == Which::BaseV3_1B {
vec![repo.get("model.safetensors")?]
} else {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
}
None => match args.which {
Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -331,7 +331,7 @@ fn main() -> Result<()> {
let model = Model2::new(args.use_flash_attn, &config, vb)?;
Model::V2(model)
}
Which::BaseV3_1B => {
Which::BaseV3_1B | Which::InstructV3_1B => {
let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
let model = Model3::new(args.use_flash_attn, &config, vb)?;
Model::V3(model)

View File

@ -12,7 +12,7 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src
** Output Example

View File

@ -0,0 +1,11 @@
# candle-llama
Candle implementations of various Llama based architectures.
## Running an example
```bash
$ cargo run --example llama -- --prompt "Machine learning is " --which v32-3b-instruct
> Machine learning is the part of computer science which deals with the development of algorithms and
```

View File

@ -21,7 +21,7 @@ impl Config {
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
self.d_model.div_ceil(16)
}
fn d_conv(&self) -> usize {

View File

@ -12,6 +12,6 @@ would only work for inference.
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```

View File

@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
mountain. I cannot stay far from you any longer.</s>
```
### Changing model and language pairs
```bash
$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
你好,你好吗?
```
## Generating the tokenizer.json files
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be install and use the `convert_slow_tokenizer.py` script from this
directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```
The tokenizer for each `marian-mt` model was trained independently,
meaning each new model needs unique tokenizer encoders and decoders.
You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
the `tokenizer.json` config files from the hf-hub repos.
The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
to be installed, and has only been tested for `python 3.12.7`.

File diff suppressed because it is too large Load Diff

View File

@ -20,6 +20,22 @@ enum Which {
Big,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum LanguagePair {
#[value(name = "fr-en")]
FrEn,
#[value(name = "en-zh")]
EnZh,
#[value(name = "en-hi")]
EnHi,
#[value(name = "en-es")]
EnEs,
#[value(name = "en-fr")]
EnFr,
#[value(name = "en-ru")]
EnRu,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
@ -36,6 +52,10 @@ struct Args {
#[arg(long, default_value = "big")]
which: Which,
// Choose which language pair to use
#[arg(long, default_value = "fr-en")]
language_pair: LanguagePair,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
let config = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
(Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
(Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
(Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
(Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
(Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
(Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
(Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
};
let tokenizer_default_repo = match args.language_pair {
LanguagePair::FrEn => "lmz/candle-marian",
LanguagePair::EnZh
| LanguagePair::EnHi
| LanguagePair::EnEs
| LanguagePair::EnFr
| LanguagePair::EnRu => "KeighBee/candle-marian",
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
let filename = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
.model(tokenizer_default_repo.to_string())
.get(filename)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
let filename = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
.model(tokenizer_default_repo.to_string())
.get(filename)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
None => {
let api = Api::new()?;
let api = match (args.which, args.language_pair) {
(Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
)),
(Which::Big, LanguagePair::FrEn) => {
api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
}
(Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-zh".to_string(),
hf_hub::RepoType::Model,
"refs/pr/13".to_string(),
)),
(Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-hi".to_string(),
hf_hub::RepoType::Model,
"refs/pr/3".to_string(),
)),
(Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-es".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
)),
(Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-fr".to_string(),
hf_hub::RepoType::Model,
"refs/pr/9".to_string(),
)),
(Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-en-ru".to_string(),
hf_hub::RepoType::Model,
"refs/pr/7".to_string(),
)),
(Which::Big, lp) => {
anyhow::bail!("big is not supported for language pair {lp:?}")
}
};
api.get("model.safetensors")?
}
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};

View File

@ -0,0 +1,53 @@
from pathlib import Path
import warnings
from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf
class MarianConverter(SpmConverter):
def __init__(self, *args, index: int = 0):
requires_backends(self, "protobuf")
super(SpmConverter, self).__init__(*args)
# from .utils import sentencepiece_model_pb2 as model_pb2
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
print(self.original_tokenizer.spm_files)
with open(self.original_tokenizer.spm_files[index], "rb") as f:
m.ParseFromString(f.read())
self.proto = m
print(self.original_tokenizer)
#with open(self.original_tokenizer.vocab_path, "r") as f:
dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
with open(dir_path / "vocab.json", "r") as f:
import json
self._vocab = json.load(f)
if self.proto.trainer_spec.byte_fallback:
if not getattr(self, "handle_byte_fallback", None):
warnings.warn(
"The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
" which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
" tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
"unknown tokens into a sequence of byte tokens matching the original piece of text."
)
def vocab(self, proto):
vocab_size = max(self._vocab.values()) + 1
vocab = [("<NIL>", -100) for _ in range(vocab_size)]
for piece in proto.pieces:
try:
index = self._vocab[piece.piece]
except Exception:
print(f"Ignored missing piece {piece.piece}")
vocab[index] = (piece.piece, piece.score)
return vocab
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")

View File

@ -0,0 +1,22 @@
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
filelock==3.18.0
fsspec==2025.3.2
huggingface-hub==0.30.1
idna==3.10
joblib==1.4.2
numpy==2.2.4
packaging==24.2
protobuf==6.30.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
sacremoses==0.1.1
safetensors==0.5.3
sentencepiece==0.2.0
tokenizers==0.21.1
tqdm==4.67.1
transformers==4.50.3
typing-extensions==4.13.0
urllib3==2.3.0

View File

@ -13,6 +13,6 @@ Note that the current candle implementation suffers from some limitations as of
## Run an example
```bash
cargo run --example metavoice --release -- \\
cargo run --example metavoice --release -- \
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```

View File

@ -0,0 +1,16 @@
# candle-mnist-training
Training a 2 layer MLP on mnist in Candle.
## Running an example
```bash
$ cargo run --example mnist-training --features candle-datasets
> train-images: [60000, 784]
> train-labels: [60000]
> test-images: [10000, 784]
> test-labels: [10000]
> 1 train loss: 2.30265 test acc: 68.08%
> 2 train loss: 1.50815 test acc: 60.77%
```

View File

@ -7,6 +7,7 @@ extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use rand::prelude::*;
use rand::rng;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
@ -138,7 +139,7 @@ fn training_loop_cnn(
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
for epoch in 1..args.epochs {
let mut sum_loss = 0f32;
batch_idxs.shuffle(&mut thread_rng());
batch_idxs.shuffle(&mut rng());
for batch_idx in batch_idxs.iter() {
let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;

View File

@ -12,7 +12,7 @@ $ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jp
Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
$ cargo run --example moondream --release -- --prompt "Describe the people behind the bikers?" --image "candle-examples/examples/yolo-v8/assets/bike.jpg"
avavx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64

View File

@ -0,0 +1,20 @@
# candle-musicgen
Candle implementation of musicgen from [Simple and Controllable Music Generation](https://arxiv.org/pdf/2306.05284).
## Running an example
```bash
$ cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums"
> tokens: [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7, 1]
> Tensor[dims 1, 13; u32]
> [[[ 0.0902, 0.1256, -0.0585, ..., 0.1057, -0.5141, -0.4675],
> [ 0.1972, -0.0268, -0.3368, ..., -0.0495, -0.3597, -0.3940],
> [-0.0855, -0.0007, 0.2225, ..., -0.2804, -0.5360, -0.2436],
> ...
> [ 0.0515, 0.0235, -0.3855, ..., -0.4728, -0.6858, -0.2923],
> [-0.3728, -0.1442, -0.1179, ..., -0.4388, -0.0287, -0.3242],
> [ 0.0163, 0.0012, -0.0020, ..., 0.0142, 0.0173, -0.0103]]]
> Tensor[[1, 13, 768], f32]
```

View File

@ -0,0 +1,20 @@
# candle-quantized-phi
Candle implementation of various quantized Phi models.
## Running an example
```bash
$ cargo run --example quantized-phi --release -- --prompt "The best thing about coding in rust is "
> - it's memory safe (without you having to worry too much)
> - the borrow checker is really smart and will catch your mistakes for free, making them show up as compile errors instead of segfaulting in runtime.
>
> This alone make me prefer using rust over c++ or go, python/Cython etc.
>
> The major downside I can see now:
> - it's slower than other languages (viz: C++) and most importantly lack of libraries to leverage existing work done by community in that language. There are so many useful machine learning libraries available for c++, go, python etc but none for Rust as far as I am aware of on the first glance.
> - there aren't a lot of production ready projects which also makes it very hard to start new one (given my background)
>
> Another downside:
```

View File

@ -27,6 +27,8 @@ enum Which {
W2_7b,
#[value(name = "72b")]
W2_72b,
#[value(name = "deepseekr1-qwen7b")]
DeepseekR1Qwen7B,
}
#[derive(Parser, Debug)]
@ -102,6 +104,7 @@ impl Args {
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
Which::DeepseekR1Qwen7B => "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
@ -135,6 +138,11 @@ impl Args {
"qwen2-72b-instruct-q4_0.gguf",
"main",
),
Which::DeepseekR1Qwen7B => (
"unsloth/DeepSeek-R1-Distill-Qwen-7B-GGUF",
"DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf",
"main",
),
};
let api = hf_hub::api::sync::Api::new()?;
api.repo(hf_hub::Repo::with_revision(
@ -211,11 +219,15 @@ fn main() -> anyhow::Result<()> {
let tokenizer = args.tokenizer()?;
let mut tos = TokenOutputStream::new(tokenizer);
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let prompt_str = format!(
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
prompt_str
);
let prompt_str = args
.prompt
.clone()
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
let prompt_str = match args.which {
Which::DeepseekR1Qwen7B => format!("<User>{prompt_str}<Assistant>"),
_ => format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"),
};
print!("formatted instruct prompt: {}", &prompt_str);
let tokens = tos
.tokenizer()
@ -260,7 +272,13 @@ fn main() -> anyhow::Result<()> {
print!("{t}");
std::io::stdout().flush()?;
}
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
let eos_token = match args.which {
Which::DeepseekR1Qwen7B => "<end▁of▁sentence>",
_ => "<|im_end|>",
};
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
let start_post_prompt = std::time::Instant::now();
let mut sampled = 0;
for index in 0..to_sample {

View File

@ -1,5 +1,7 @@
# candle-quantized-t5
Candle implementation for quantizing and running T5 translation models.
## Seq2Seq example
This example uses a quantized version of the t5 model.

View File

@ -75,6 +75,8 @@ enum Which {
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-1.7B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "deepseekr1-llama8b")]
DeepseekR1Llama8b,
}
impl Which {
@ -94,7 +96,8 @@ impl Which {
| Self::L8b
| Self::Phi3
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct => false,
| Self::SmolLM2_360MInstruct
| Self::DeepseekR1Llama8b => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
Self::OpenChat35
@ -132,7 +135,8 @@ impl Which {
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3 => false,
| Self::Phi3
| Self::DeepseekR1Llama8b => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
@ -160,11 +164,41 @@ impl Which {
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3 => false,
| Self::Phi3
| Self::DeepseekR1Llama8b => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
fn is_deepseek(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::L8b
| Self::SmolLM2_1BInstruct
| Self::SmolLM2_360MInstruct
| Self::Phi3
| Self::OpenChat35
| Self::Starling7bAlpha => false,
Self::DeepseekR1Llama8b => true,
}
}
fn tokenizer_repo(&self) -> &'static str {
match self {
Self::L7b
@ -191,6 +225,7 @@ impl Which {
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
Self::DeepseekR1Llama8b => "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}
}
}
@ -363,6 +398,10 @@ impl Args {
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
"smollm2-1.7b-instruct-q4_k_m.gguf",
),
Which::DeepseekR1Llama8b => (
"unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
"DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf",
),
};
let revision = if self.which == Which::Phi3 {
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
@ -477,6 +516,7 @@ fn main() -> anyhow::Result<()> {
| Which::L8b
| Which::SmolLM2_1BInstruct
| Which::SmolLM2_360MInstruct
| Which::DeepseekR1Llama8b
| Which::Phi3 => 1,
Which::Mixtral
| Which::MixtralInstruct
@ -530,6 +570,8 @@ fn main() -> anyhow::Result<()> {
}
} else if args.which.is_mistral() {
format!("[INST] {prompt} [/INST]")
} else if args.which.is_deepseek() {
format!("<User>{prompt}<Assistant>")
} else {
prompt
}
@ -597,6 +639,7 @@ fn main() -> anyhow::Result<()> {
let eos_token = match args.which {
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
Which::L8b => "<|end_of_text|>",
Which::DeepseekR1Llama8b => "<end▁of▁sentence>",
_ => match args.which.is_open_chat() {
true => "<|end_of_turn|>",
false => "</s>",

View File

@ -2,6 +2,11 @@
Reinforcement Learning examples for candle.
> [!WARNING]
> uv is not currently compatible with pyo3 as of 2025/3/28.
## System wide python
This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash

View File

@ -5,7 +5,7 @@ use candle_nn::{
func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};
use super::gym_env::GymEnv;
@ -103,8 +103,8 @@ impl ReplayBuffer {
if self.size < batch_size {
Ok(None)
} else {
let transitions: Vec<&Transition> = thread_rng()
.sample_iter(Uniform::from(0..self.size))
let transitions: Vec<&Transition> = rng()
.sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
.take(batch_size)
.map(|i| self.buffer.get(i).unwrap())
.collect();
@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;

View File

@ -1,9 +1,8 @@
use std::collections::VecDeque;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};
use candle::{DType, Device, Module, Result, Tensor};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};
@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
// fed to the model so that it performs a backward pass.
if memory.len() > BATCH_SIZE {
// Sample randomly from the memory.
let batch = thread_rng()
.sample_iter(Uniform::from(0..memory.len()))
let batch = rng()
.sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
.take(BATCH_SIZE)
.map(|i| memory.get(i).unwrap().clone())
.collect::<Vec<_>>();

View File

@ -4,7 +4,7 @@ use candle_nn::{
linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
use rand::{distr::Distribution, rngs::ThreadRng, Rng};
fn new_model(
input_shape: &[usize],
@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
}
fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
let mut rng = rng;
Ok(distribution.sample(&mut rng))
}
@ -65,10 +65,10 @@ pub fn run() -> Result<()> {
let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;
let mut rng = rand::thread_rng();
let mut rng = rand::rng();
for epoch_idx in 0..100 {
let mut state = env.reset(rng.gen::<u64>())?;
let mut state = env.reset(rng.random::<u64>())?;
let mut steps: Vec<Step<i64>> = vec![];
loop {
@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
steps.push(step.copy_with_obs(&state));
if step.terminated || step.truncated {
state = env.reset(rng.gen::<u64>())?;
state = env.reset(rng.random::<u64>())?;
if steps.len() > 5000 {
break;
}

View File

@ -7,7 +7,7 @@ probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example resnet --release -- --image tiger.jpg
$ cargo run --example resnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built

View File

@ -10,9 +10,11 @@ If you want you can use the example images from this [pull request][pr], downloa
```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
cargo run --example segformer classify candle-examples/examples/yolo-v8/assets/bike.jpg
# run the segmentation task
cargo run --example segformer segment <path-to-image>
cargo run --example segformer segment candle-examples/examples/yolo-v8/assets/bike.jpg
```
Example output for classification:

View File

@ -14,8 +14,8 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
```bash
cargo run --example segment-anything --release -- \
--image candle-examples/examples/yolo-v8/assets/bike.jpg
--use-tiny
--image candle-examples/examples/yolo-v8/assets/bike.jpg \
--use-tiny \
--point 0.6,0.6 --point 0.6,0.55
```

View File

@ -5,7 +5,7 @@ SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmo
### Running an example
```
$ cargo run --features cuda -r --example siglip -
$ cargo run --features cuda -r --example siglip
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]

View File

@ -6,7 +6,14 @@ This example uses the models available in the hugging face [onnx-community/siler
## Running the example
### using arecord
```bash
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```
### using SoX
```bash
$ rec -t raw -r 48000 -b 16 -c 1 -e signed-integer - trim 0 5 | sox -t raw -r 48000 -b 16 -c 1 -e signed-integer - -t raw -r 16000 -b 16 -c 1 -e signed-integer - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```

View File

@ -0,0 +1,15 @@
# candle-starcoder2
Candle implementation of Star Coder 2 family of code generation model from [StarCoder 2 and The Stack v2: The Next Generation](https://arxiv.org/pdf/2402.19173).
## Running an example
```bash
$ cargo run --example starcoder2 -- --prompt "write a recursive fibonacci function in python "
> # that returns the nth number in the sequence.
>
> def fib(n):
> if n
```

View File

@ -10,7 +10,7 @@ Stella_en_1.5B_v5 is used to generate text embeddings embeddings for a prompt. T
are downloaded from the hub on the first run.
```bash
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?" --which 1.5b
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
> Tensor[[1, 1024], f32]

View File

@ -1,5 +1,7 @@
# candle-t5
Candle implementations of the T5 family of translation models.
## Encoder-decoder example:
```bash

View File

@ -7,7 +7,7 @@ The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main
You can run the example with the following command:
```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
cargo run --example vgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which vgg13
```
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).

View File

@ -7,8 +7,8 @@ probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example vit --release -- --image tiger.jpg
```bash
$ cargo run --example vit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built

View File

@ -0,0 +1,15 @@
# candle-whisper-microphone
Whisper implementation using microphone as input.
## Running an example
```bash
$ cargo run --example whisper-microphone --features microphone
> transcribing audio...
> 480256 160083
> language_token: None
> 0.0s -- 30.0s: Hello, hello, I don't know if this is working, but You know, how long did I make this?
> 480256 160085
```

View File

@ -9,7 +9,7 @@ use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};
use tokenizers::Tokenizer;
mod multilingual;
@ -204,7 +204,7 @@ impl Decoder {
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;

View File

@ -14,7 +14,9 @@ use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use rand::distr::weighted::WeightedIndex;
use rand::distr::Distribution;
use rand::SeedableRng;
use tokenizers::Tokenizer;
mod multilingual;
@ -208,7 +210,7 @@ impl Decoder {
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
let distr = WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;

View File

@ -0,0 +1,13 @@
# candle-yi
Candle implentations of the Yi family of bilingual (English, Chinese) LLMs.
## Running an example
```bash
$ cargo run --example yi -- --prompt "Here is a test sentence"
> python
> print("Hello World")
>
```

View File

@ -0,0 +1,32 @@
# candle-yolo-v3:
Candle implementation of Yolo-V3 for object detection.
## Running an example
```bash
$ cargo run --example yolo-v3 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
> generated predictions Tensor[dims 10647, 85; f32]
> person: Bbox { xmin: 46.362198, ymin: 72.177, xmax: 135.92522, ymax: 339.8356, confidence: 0.99705493, data: () }
> person: Bbox { xmin: 137.25645, ymin: 67.58148, xmax: 216.90437, ymax: 333.80756, confidence: 0.9898516, data: () }
> person: Bbox { xmin: 245.7842, ymin: 82.76726, xmax: 316.79053, ymax: 337.21613, confidence: 0.9884322, data: () }
> person: Bbox { xmin: 207.52783, ymin: 61.815224, xmax: 266.77884, ymax: 307.92606, confidence: 0.9860648, data: () }
> person: Bbox { xmin: 11.457404, ymin: 60.335564, xmax: 34.39357, ymax: 187.7714, confidence: 0.9545012, data: () }
> person: Bbox { xmin: 251.88353, ymin: 11.235481, xmax: 286.56607, ymax: 92.54697, confidence: 0.8439807, data: () }
> person: Bbox { xmin: -0.44309902, ymin: 55.486923, xmax: 13.160354, ymax: 184.09705, confidence: 0.8266243, data: () }
> person: Bbox { xmin: 317.40826, ymin: 55.39501, xmax: 370.6704, ymax: 153.74887, confidence: 0.7327442, data: () }
> person: Bbox { xmin: 370.02835, ymin: 66.120224, xmax: 404.22824, ymax: 142.09691, confidence: 0.7265741, data: () }
> person: Bbox { xmin: 250.36511, ymin: 57.349842, xmax: 280.06335, ymax: 116.29384, confidence: 0.709422, data: () }
> person: Bbox { xmin: 32.573215, ymin: 66.66239, xmax: 50.49056, ymax: 173.42068, confidence: 0.6998766, data: () }
> person: Bbox { xmin: 131.72215, ymin: 63.946213, xmax: 166.66151, ymax: 241.52773, confidence: 0.64457536, data: () }
> person: Bbox { xmin: 407.42416, ymin: 49.106407, xmax: 415.24307, ymax: 84.7134, confidence: 0.5955802, data: () }
> person: Bbox { xmin: 51.650482, ymin: 64.4985, xmax: 67.40904, ymax: 106.952385, confidence: 0.5196007, data: () }
> bicycle: Bbox { xmin: 160.10031, ymin: 183.90837, xmax: 200.86832, ymax: 398.609, confidence: 0.9623588, data: () }
> bicycle: Bbox { xmin: 66.570915, ymin: 192.56966, xmax: 112.06765, ymax: 369.28497, confidence: 0.9174347, data: () }
> bicycle: Bbox { xmin: 258.2856, ymin: 197.04532, xmax: 298.43106, ymax: 364.8627, confidence: 0.6851388, data: () }
> bicycle: Bbox { xmin: 214.0034, ymin: 175.76498, xmax: 252.45158, ymax: 356.53818, confidence: 0.67071193, data: () }
> motorbike: Bbox { xmin: 318.23938, ymin: 95.22487, xmax: 369.9743, ymax: 213.46263, confidence: 0.96691036, data: () }
> motorbike: Bbox { xmin: 367.46417, ymin: 100.07982, xmax: 394.9981, ymax: 174.6545, confidence: 0.9185384, data: () }
> writing "candle-examples/examples/yolo-v8/assets/bike.pp.jpg"
```

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.4"
version = "0.9.0-alpha.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]

View File

@ -88,19 +88,26 @@ fn main() -> Result<()> {
.arg("--use_fast_math")
.arg("--verbose");
let mut is_target_msvc = false;
if let Ok(target) = std::env::var("TARGET") {
if target.contains("msvc") {
is_target_msvc = true;
builder = builder.arg("-D_USE_MATH_DEFINES");
}
}
if !is_target_msvc {
builder = builder.arg("-Xcompiler").arg("-fPIC");
}
let out_file = build_dir.join("libflashattention.a");
builder.build_lib(out_file);
println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");
if !is_target_msvc {
println!("cargo:rustc-link-lib=dylib=stdc++");
}
Ok(())
}

View File

@ -88,6 +88,7 @@ impl FlashAttn {
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}
let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
@ -114,7 +115,9 @@ impl FlashAttn {
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
// Dropping the guard here doesn't seem very safe.
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
ptr as *const core::ffi::c_void
} else {
std::ptr::null()
};
@ -161,17 +164,17 @@ impl FlashAttn {
}
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
let (q_ptr, _guard) = q.device_ptr(&stream);
let (k_ptr, _guard) = k.device_ptr(&stream);
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
ffi::run_mha(
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
@ -550,6 +553,7 @@ impl FlashAttnVarLen {
let batch_size = nseqlens_q - 1;
let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
@ -576,7 +580,9 @@ impl FlashAttnVarLen {
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
// Dropping the guard here doesn't seem very safe.
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
ptr as *const core::ffi::c_void
} else {
std::ptr::null()
};
@ -621,22 +627,22 @@ impl FlashAttnVarLen {
}
unsafe {
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
let (q_ptr, _guard) = q.device_ptr(&stream);
let (k_ptr, _guard) = k.device_ptr(&stream);
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
ffi::run_mha(
q_ptr,
k_ptr,
v_ptr,
dst_ptr,
softmax_lse_ptr,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
/* q_batch_stride */ 0,
/* k_batch_stride */ 0,
/* v_batch_stride */ 0,

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.4"
version = "0.9.0-alpha.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -7,5 +7,5 @@ fn main() {
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write("src/lib.rs").unwrap();
bindings.write("src/ptx.rs").unwrap();
}

View File

@ -1,11 +1,78 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
mod ptx;
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Id {
Affine,
Binary,
Cast,
Conv,
Fill,
Indexing,
Quantized,
Reduce,
Sort,
Ternary,
Unary,
}
pub const ALL_IDS: [Id; 11] = [
Id::Affine,
Id::Binary,
Id::Cast,
Id::Conv,
Id::Fill,
Id::Indexing,
Id::Quantized,
Id::Reduce,
Id::Sort,
Id::Ternary,
Id::Unary,
];
pub struct Module {
index: usize,
ptx: &'static str,
}
impl Module {
pub fn index(&self) -> usize {
self.index
}
pub fn ptx(&self) -> &'static str {
self.ptx
}
}
const fn module_index(id: Id) -> usize {
let mut i = 0;
while i < ALL_IDS.len() {
if ALL_IDS[i] as u32 == id as u32 {
return i;
}
i += 1;
}
panic!("id not found")
}
macro_rules! mdl {
($cst:ident, $id:ident) => {
pub const $cst: Module = Module {
index: module_index(Id::$id),
ptx: ptx::$cst,
};
};
}
mdl!(AFFINE, Affine);
mdl!(BINARY, Binary);
mdl!(CAST, Cast);
mdl!(CONV, Conv);
mdl!(FILL, Fill);
mdl!(INDEXING, Indexing);
mdl!(QUANTIZED, Quantized);
mdl!(REDUCE, Reduce);
mdl!(SORT, Sort);
mdl!(TERNARY, Ternary);
mdl!(UNARY, Unary);

11
candle-kernels/src/ptx.rs Normal file
View File

@ -0,0 +1,11 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.8.4"
version = "0.9.0-alpha.1"
edition = "2021"
description = "Metal kernels for Candle"

View File

@ -11,6 +11,7 @@ pub struct Cache {
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
grow_by: usize,
max_seq_len: usize,
}
@ -20,6 +21,7 @@ impl Cache {
all_data: None,
dim,
current_seq_len: 0,
grow_by: max_seq_len,
max_seq_len,
}
}
@ -65,11 +67,11 @@ impl Cache {
};
let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!(
"kv-cache: above max-seq-len {}+{seq_len}>{}",
self.current_seq_len,
self.max_seq_len
)
let mut shape = src.dims().to_vec();
shape[self.dim] = self.grow_by;
let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
*ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
self.max_seq_len += self.grow_by;
}
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;

View File

@ -7,7 +7,7 @@ use candle::{Result, Tensor};
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain log probabilities.
/// of categories. This is expected to contain log probabilities.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to raw logits.
/// of categories. This is expected to raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to raw logits.
/// of categories. This is expected to raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
/// of categories.
/// of categories.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {

View File

@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
};
use candle::cuda_backend::SlicePtrOrNull;
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let params = (el_count, dims.len(), &ds, src, &out);
let mut builder = func.builder();
candle::builder_arg!(builder, el_count, dims.len());
ds.builder_arg(&mut builder);
builder.arg(src);
builder.arg(&out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(out)
}
}
@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
use candle::{CudaDevice, WithDType};
@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
block_dim: (1, 32, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, n_cols as i32);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
candle::builder_arg!(builder, n_cols as i32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
l2: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
use candle::{CudaDevice, WithDType};
@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
use candle::{CudaDevice, WithDType};
@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
let func =
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
builder.arg(&beta);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}

View File

@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
let (b, h, t, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src,
&cos,
&sin,
&dst,
(b * h) as u32,
(t * d) as u32,
d as u32,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
};
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
use candle::{CudaDevice, WithDType};
@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
let (b, t, h, d) = l_src.shape().dims4()?;
let el = b * h * t * d;
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
builder.arg(&sin);
builder.arg(&dst);
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.8.4"
version = "0.9.0-alpha.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
candle-nn = { path = "../candle-nn", version = "0.8.4" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -0,0 +1,533 @@
//! Implementation of the Conversational Speech Model (CSM) from Sesame
//!
//! See: [CSM](Conversational Speech Model)
//!
/// CSM (Conversational Speech Model) is a speech generation model from Sesame that generates RVQ
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
/// smaller audio decoder that produces Mimi audio codes.
///
use crate::generation::LogitsProcessor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
use std::sync::Arc;
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Flavor {
#[serde(rename = "llama-1B")]
Llama1B,
#[serde(rename = "llama-100M")]
Llama100M,
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub audio_num_codebooks: usize,
pub audio_vocab_size: usize,
pub backbone_flavor: Flavor,
pub decoder_flavor: Flavor,
pub text_vocab_size: usize,
}
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct LlamaConfig {
vocab_size: usize,
num_layers: usize,
num_heads: usize,
num_kv_heads: usize,
embed_dim: usize,
max_seq_len: usize,
intermediate_dim: usize,
norm_eps: f64,
rope_base: f32,
scale_factor: usize,
}
impl LlamaConfig {
pub fn from_flavor(flavor: Flavor) -> Self {
match flavor {
Flavor::Llama1B => Self {
vocab_size: 128256,
num_layers: 16,
num_heads: 32,
num_kv_heads: 8,
embed_dim: 2048,
max_seq_len: 2048,
intermediate_dim: 8192,
norm_eps: 1e-5,
rope_base: 500_000.,
scale_factor: 32,
},
Flavor::Llama100M => Self {
vocab_size: 128256,
num_layers: 4,
num_heads: 8,
num_kv_heads: 2,
embed_dim: 1024,
max_seq_len: 2048,
intermediate_dim: 8192,
norm_eps: 1e-5,
rope_base: 500_000.,
scale_factor: 32,
},
}
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
let head_dim = cfg.embed_dim / cfg.num_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
.collect()
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
let low_freq_factor = 1.0;
let high_freq_factor = 4.0;
let original_max_position_embeddings = 8192;
let scale_factor = cfg.scale_factor as f32;
let theta = {
let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
calculate_default_inv_freq(cfg)
.into_iter()
.map(|freq| {
let wavelen = 2. * std::f32::consts::PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / scale_factor
} else {
let smooth = (original_max_position_embeddings as f32 / wavelen
- low_freq_factor)
/ (high_freq_factor - low_freq_factor);
(1. - smooth) * freq / scale_factor + smooth * freq
}
})
.collect::<Vec<_>>()
};
let theta = Tensor::new(theta, dev)?;
let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
let cos = idx_theta.cos()?.to_dtype(dtype)?;
let sin = idx_theta.sin()?.to_dtype(dtype)?;
Ok(Self { cos, sin })
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
let weight = vb.get((hidden_size,), "scale")?;
Ok(RmsNorm::new(weight, eps))
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
num_heads: usize,
head_dim: usize,
num_kv_heads: usize,
num_kv_groups: usize,
}
impl Attention {
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
let head_dim = cfg.embed_dim / cfg.num_heads;
let kv_dim = cfg.num_kv_heads * head_dim;
let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
rotary_emb,
kv_cache: None,
num_heads: cfg.num_heads,
num_kv_heads: cfg.num_kv_heads,
num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
head_dim,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.num_heads * self.head_dim))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct Mlp {
w1: Linear,
w2: Linear,
w3: Linear,
}
impl Mlp {
fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
let w1 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w1"))?;
let w2 = linear_b(cfg.intermediate_dim, cfg.embed_dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w3"))?;
Ok(Self { w1, w2, w3 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.w1)?.silu()?;
let rhs = xs.apply(&self.w3)?;
(lhs * rhs)?.apply(&self.w2)
}
}
#[derive(Debug, Clone)]
struct Layer {
mlp_norm: RmsNorm,
sa_norm: RmsNorm,
attn: Attention,
mlp: Mlp,
}
impl Layer {
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
let mlp_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("mlp_norm"))?;
let sa_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("sa_norm"))?;
let attn = Attention::new(cfg, rotary_emb, vb.pp("attn"))?;
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
Ok(Self {
mlp_norm,
sa_norm,
attn,
mlp,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.sa_norm.forward(xs)?;
let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct LlamaModel {
layers: Vec<Layer>,
norm: RmsNorm,
device: Device,
dtype: DType,
}
impl LlamaModel {
pub fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
let mut layers = Vec::with_capacity(cfg.num_layers);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.num_layers {
let layer = Layer::new(cfg, rotary_emb.clone(), vb_l.pp(layer_idx))?;
layers.push(layer);
}
let norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("norm"))?;
Ok(Self {
layers,
norm,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
fn prepare_decoder_attention_mask(
&self,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len, _embed_dim) = xs.dims3()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = xs.clone();
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
}
let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
Ok(ys)
}
}
#[derive(Debug, Clone)]
pub struct Model {
backbone: LlamaModel,
decoder: LlamaModel,
codebook0_head: Linear,
audio_embeddings: Embedding,
text_embeddings: Embedding,
projection: Linear,
audio_head: Tensor,
config: Config,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let backbone_cfg = LlamaConfig::from_flavor(cfg.backbone_flavor);
let backbone = LlamaModel::new(&backbone_cfg, vb.pp("backbone"))?;
let decoder_cfg = LlamaConfig::from_flavor(cfg.decoder_flavor);
let decoder = LlamaModel::new(&decoder_cfg, vb.pp("decoder"))?;
let backbone_dim = backbone_cfg.embed_dim;
let decoder_dim = decoder_cfg.embed_dim;
let audio_embeddings = embedding(
cfg.audio_vocab_size * cfg.audio_num_codebooks,
backbone_dim,
vb.pp("audio_embeddings"),
)?;
let text_embeddings =
embedding(cfg.text_vocab_size, backbone_dim, vb.pp("text_embeddings"))?;
let projection = linear_b(backbone_dim, decoder_dim, false, vb.pp("projection"))?;
let codebook0_head = linear_b(
backbone_dim,
cfg.audio_vocab_size,
false,
vb.pp("codebook0_head"),
)?;
let audio_head = vb.get(
(
cfg.audio_num_codebooks - 1,
decoder_dim,
cfg.audio_vocab_size,
),
"audio_head",
)?;
Ok(Self {
backbone,
decoder,
codebook0_head,
audio_embeddings,
text_embeddings,
projection,
audio_head,
config: cfg.clone(),
})
}
pub fn clear_kv_cache(&mut self) {
self.backbone.clear_kv_cache();
self.decoder.clear_kv_cache();
}
pub fn generate_frame(
&mut self,
tokens: &Tensor,
tokens_mask: &Tensor,
input_pos: usize,
lp: &mut LogitsProcessor,
) -> Result<Vec<u32>> {
let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
let text_embeds = self.text_embeddings.forward(&text_tokens)?;
let arange = (Tensor::arange(
0u32,
self.config.audio_num_codebooks as u32,
&self.decoder.device,
)? * self.config.audio_vocab_size as f64)?;
let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
b_sz,
seq_len,
self.config.audio_num_codebooks,
(),
))?;
let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
let embeds = embeds.broadcast_mul(
&tokens_mask
.to_dtype(self.backbone.dtype)?
.unsqueeze(D::Minus1)?,
)?;
let embeds = embeds.sum(2)?;
let h = self.backbone.forward(&embeds, input_pos)?;
let c0_logits = h.apply(&self.codebook0_head)?;
let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
let mut all_samples = vec![c0_sample];
let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
self.decoder.clear_kv_cache();
let mut decoder_pos = 0;
for i in 1..self.config.audio_num_codebooks {
let proj_h = curr_h.apply(&self.projection)?;
let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
decoder_pos += curr_h.dim(1)?;
let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
all_samples.push(ci_sample);
let ci_sample = Tensor::from_slice(
&[ci_sample + (i * self.config.audio_vocab_size) as u32],
(1, 1),
&self.decoder.device,
)?;
let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
curr_h = ci_embed
}
Ok(all_samples)
}
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
let cb = self.config.audio_num_codebooks;
let device = &self.backbone.device;
let mut mask = vec![1u8; cb];
mask.push(0);
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
frame.push(0);
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
Ok((tokens, mask))
}
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
let cb = self.config.audio_num_codebooks;
let device = &self.backbone.device;
let mut tokens = vec![];
let mut mask = vec![];
for &v in ids.iter() {
let mut token = vec![0; cb];
token.push(v);
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
tokens.push(token);
let mut m = vec![0u8; cb];
m.push(1);
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
mask.push(m);
}
let tokens = Tensor::cat(&tokens, 1)?;
let mask = Tensor::cat(&mask, 1)?;
Ok((tokens, mask))
}
}

View File

@ -104,7 +104,7 @@ impl EncoderBlock {
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
let cfg1 = Conv1dConfig {
stride,
padding: (stride + 1) / 2,
padding: stride.div_ceil(2),
..Default::default()
};
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
@ -196,7 +196,7 @@ impl DecoderBlock {
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
let cfg = ConvTranspose1dConfig {
stride,
padding: (stride + 1) / 2,
padding: stride.div_ceil(2),
..Default::default()
};
let conv_tr1 = encodec::conv_transpose1d_weight_norm(

View File

@ -6,8 +6,8 @@ pub fn get_noise(
width: usize,
device: &Device,
) -> Result<Tensor> {
let height = (height + 15) / 16 * 2;
let width = (width + 15) / 16 * 2;
let height = height.div_ceil(16) * 2;
let width = width.div_ceil(16) * 2;
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
}
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
let (b, _h_w, c_ph_pw) = xs.dims3()?;
let height = (height + 15) / 16;
let width = (width + 15) / 16;
let height = height.div_ceil(16);
let width = width.div_ceil(16);
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
.reshape((b, c_ph_pw / 4, height * 2, width * 2))

View File

@ -27,7 +27,7 @@ impl Config {
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
self.d_model.div_ceil(16)
}
fn d_inner(&self) -> usize {

View File

@ -81,6 +81,126 @@ impl Config {
vocab_size: 59514,
}
}
pub fn opus_mt_en_zh() -> Self {
Self {
activation_function: candle_nn::Activation::Swish,
d_model: 512,
decoder_attention_heads: 8,
decoder_ffn_dim: 2048,
decoder_layers: 6,
decoder_start_token_id: 65000,
decoder_vocab_size: Some(65001),
encoder_attention_heads: 8,
encoder_ffn_dim: 2048,
encoder_layers: 6,
eos_token_id: 0,
forced_eos_token_id: 0,
is_encoder_decoder: true,
max_position_embeddings: 512,
pad_token_id: 65000,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 65001,
}
}
pub fn opus_mt_en_hi() -> Self {
Self {
activation_function: candle_nn::Activation::Swish,
d_model: 512,
decoder_attention_heads: 8,
decoder_ffn_dim: 2048,
decoder_layers: 6,
decoder_start_token_id: 61949,
decoder_vocab_size: Some(61950),
encoder_attention_heads: 8,
encoder_ffn_dim: 2048,
encoder_layers: 6,
eos_token_id: 0,
forced_eos_token_id: 0,
is_encoder_decoder: true,
max_position_embeddings: 512,
pad_token_id: 61949,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 61950,
}
}
pub fn opus_mt_en_es() -> Self {
Self {
activation_function: candle_nn::Activation::Swish,
d_model: 512,
decoder_attention_heads: 8,
decoder_ffn_dim: 2048,
decoder_layers: 6,
decoder_start_token_id: 65000,
decoder_vocab_size: Some(65001),
encoder_attention_heads: 8,
encoder_ffn_dim: 2048,
encoder_layers: 6,
eos_token_id: 0,
forced_eos_token_id: 0,
is_encoder_decoder: true,
max_position_embeddings: 512,
pad_token_id: 65000,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 65001,
}
}
pub fn opus_mt_en_fr() -> Self {
Self {
activation_function: candle_nn::Activation::Swish,
d_model: 512,
decoder_attention_heads: 8,
decoder_ffn_dim: 2048,
decoder_layers: 6,
decoder_start_token_id: 59513,
decoder_vocab_size: Some(59514),
encoder_attention_heads: 8,
encoder_ffn_dim: 2048,
encoder_layers: 6,
eos_token_id: 0,
forced_eos_token_id: 0,
is_encoder_decoder: true,
max_position_embeddings: 512,
pad_token_id: 59513,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 59514,
}
}
pub fn opus_mt_en_ru() -> Self {
Self {
activation_function: candle_nn::Activation::Swish,
d_model: 512,
decoder_attention_heads: 8,
decoder_ffn_dim: 2048,
decoder_layers: 6,
decoder_start_token_id: 62517,
decoder_vocab_size: Some(62518),
encoder_attention_heads: 8,
encoder_ffn_dim: 2048,
encoder_layers: 6,
eos_token_id: 0,
forced_eos_token_id: 0,
is_encoder_decoder: true,
max_position_embeddings: 512,
pad_token_id: 62517,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 62518,
}
}
}
#[derive(Debug, Clone)]

View File

@ -716,7 +716,7 @@ pub mod transformer {
None => {
let hidden_dim = self.dim * 4;
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
(n_hidden + 255) / 256 * 256
n_hidden.div_ceil(256) * 256
}
}
}

View File

@ -27,6 +27,7 @@ pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod csm;
pub mod dac;
pub mod debertav2;
pub mod deepseek2;

View File

@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_<T: Float>(
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded.extend(std::iter::repeat_n(zero, to_add));
samples_padded
};

View File

@ -177,7 +177,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
let samples = {
let mut samples_padded = samples.to_vec();
let to_add = n_len * fft_step - samples.len();
samples_padded.extend(std::iter::repeat(zero).take(to_add));
samples_padded.extend(std::iter::repeat_n(zero, to_add));
samples_padded
};