Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00.

Compare commits: 36 commits (SHA1):
5341bf4cd5
8977c31b6d
3be12b8b50
825119ac4b
e319cd78d9
3fb67e0c2c
d72c44705c
2203f0e3c9
01e895c1aa
648596c073
d9904a3baf
d6db305829
b4daa03e59
9541467d6b
6429609090
ba473290da
59c26195db
cb02b389d5
0d4097031c
10853b803c
f3d472952f
67b85f79f1
0b24f7f0a4
3afb04925a
cbf5fc80c2
468d1d525f
c930ab7e1a
111edbc4ea
e286cf7cc9
e4ffb85228
37db86ff79
add3a714aa
26c16923b9
9e8bf70333
ac9cdbd448
e6cc76fc37
.github/workflows/maturin.yml (vendored): binary file not shown.
Cargo.toml: 34 changed lines.
```diff
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"

 [workspace.package]
-version = "0.8.3"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,21 +33,21 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.8.3" }
-candle-datasets = { path = "./candle-datasets", version = "0.8.3" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.3" }
-candle-kernels = { path = "./candle-kernels", version = "0.8.3" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.3" }
-candle-nn = { path = "./candle-nn", version = "0.8.3" }
-candle-onnx = { path = "./candle-onnx", version = "0.8.3" }
-candle-transformers = { path = "./candle-transformers", version = "0.8.3" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
+candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
-half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
+half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 hound = "3.5.1"
 image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
 imageproc = { version = "0.24.0", default-features = false }
@@ -58,21 +58,21 @@ memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
 parquet = { version = "51.0.0" }
-rand = "0.8.5"
-rand_distr = "0.4.3"
+rand = "0.9.0"
+rand_distr = "0.5.1"
 rayon = "1.7.0"
 safetensors = "0.4.1"
 serde = { version = "1.0.171", features = ["derive"] }
 serde_plain = "1.0.2"
 serde_json = "1.0.99"
 thiserror = "1"
-tokenizers = { version = "0.19.1", default-features = false }
+tokenizers = { version = "0.21.0", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
-ug = "0.1.0"
-ug-cuda = "0.1.0"
-ug-metal = "0.1.0"
+ug = "0.2.0"
+ug-cuda = "0.2.0"
+ug-metal = "0.2.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "1.1.1", default-features = false }
 metal = { version = "0.27.0", features = ["mps"]}
```
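In short, this compare bumps the workspace from 0.8.3 to 0.9.0-alpha.1 and pulls in cudarc 0.14.0, rand 0.9.0 with rand_distr 0.5.1, half 2.5.0, tokenizers 0.21.0, and the ug crates at 0.2.0. A minimal sketch of a downstream manifest tracking this pre-release (crate names from the workspace above; the exact alpha number will advance over time):

```toml
[dependencies]
# Cargo only matches a pre-release when the requirement itself names one,
# so the alpha has to be written out explicitly.
candle-core = "0.9.0-alpha.1"
candle-nn = "0.9.0-alpha.1"
```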
```diff
@@ -21,7 +21,9 @@ impl BenchDevice for Device {
             Device::Cpu => Ok(()),
             Device::Cuda(device) => {
                 #[cfg(feature = "cuda")]
-                return Ok(device.synchronize()?);
+                return Ok(device
+                    .synchronize()
+                    .map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
                 #[cfg(not(feature = "cuda"))]
                 panic!("Cuda device without cuda feature enabled: {:?}", device)
             }
```
```diff
@@ -32,7 +32,7 @@ impl Tensor {
     /// elements having dependencies on the latter ones, e.g. the first element if any is the
     /// argument.
    /// This assumes that the op graph is a DAG.
-    fn sorted_nodes(&self) -> Vec<&Tensor> {
+    pub fn sorted_nodes(&self) -> Vec<&Tensor> {
         // The vec of sorted nodes is passed as an owned value rather than a mutable reference
         // to get around some lifetime limitations.
         fn walk<'a>(
```
```diff
@@ -2482,15 +2482,15 @@ impl BackendDevice for CpuDevice {
         use rand::prelude::*;

         let elem_count = shape.elem_count();
-        let mut rng = rand::thread_rng();
+        let mut rng = rand::rng();
         match dtype {
             DType::U8 | DType::U32 | DType::I64 => {
                 Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
             }
             DType::BF16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform =
-                    rand::distributions::Uniform::new(bf16::from_f64(min), bf16::from_f64(max));
+                let uniform = rand::distr::Uniform::new(bf16::from_f64(min), bf16::from_f64(max))
+                    .map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<bf16, _>(uniform))
                 }
@@ -2498,8 +2498,8 @@ impl BackendDevice for CpuDevice {
             }
             DType::F16 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform =
-                    rand::distributions::Uniform::new(f16::from_f64(min), f16::from_f64(max));
+                let uniform = rand::distr::Uniform::new(f16::from_f64(min), f16::from_f64(max))
+                    .map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f16, _>(uniform))
                 }
@@ -2507,7 +2508,8 @@ impl BackendDevice for CpuDevice {
             }
             DType::F32 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
+                let uniform =
+                    rand::distr::Uniform::new(min as f32, max as f32).map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f32, _>(uniform))
                 }
@@ -2515,7 +2516,7 @@ impl BackendDevice for CpuDevice {
             }
             DType::F64 => {
                 let mut data = Vec::with_capacity(elem_count);
-                let uniform = rand::distributions::Uniform::new(min, max);
+                let uniform = rand::distr::Uniform::new(min, max).map_err(Error::wrap)?;
                 for _i in 0..elem_count {
                     data.push(rng.sample::<f64, _>(uniform))
                 }
@@ -2528,7 +2529,7 @@ impl BackendDevice for CpuDevice {
         use rand::prelude::*;

         let elem_count = shape.elem_count();
-        let mut rng = rand::thread_rng();
+        let mut rng = rand::rng();
         match dtype {
             DType::U8 | DType::U32 | DType::I64 => {
                 Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
```
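The rand 0.9 migration above follows a fixed recipe: `rand::thread_rng()` becomes `rand::rng()`, the `rand::distributions` module is now `rand::distr`, and `Uniform::new` returns a `Result` instead of panicking on an invalid range. A minimal standalone sketch of the same pattern:

```rust
use rand::prelude::*;

// Sketch of the rand 0.9 API used in the diff above.
// (Similarly, `rng.gen::<f32>()` is now `rng.random::<f32>()`.)
fn sample_uniform(min: f64, max: f64, n: usize) -> Result<Vec<f64>, rand::distr::uniform::Error> {
    let mut rng = rand::rng(); // was: rand::thread_rng()
    // `Uniform::new` now validates the range and returns a Result.
    let uniform = rand::distr::Uniform::new(min, max)?;
    Ok((0..n).map(|_| rng.sample(uniform)).collect())
}
```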
```diff
@@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
     if let Some(cudnn) = cudnn.borrow().get(&device_id) {
         return Ok(cudnn.clone());
     }
-    let c = Cudnn::new(dev.cuda_device());
+    let c = Cudnn::new(dev.cuda_stream());
     if let Ok(c) = &c {
         cudnn.borrow_mut().insert(device_id, c.clone());
     }
@@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
         Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
     };
     let workspace_size = conv2d.get_workspace_size(alg)?;
-    let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
+    let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
     unsafe {
         conv2d.launch::<CudaSlice<u8>, _, _, _>(
             alg,
```
```diff
@@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
 use half::{bf16, f16};
+use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

 use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
@@ -24,10 +25,17 @@ impl DeviceId {
 struct CudaRng(cudarc::curand::CudaRng);
 unsafe impl Send for CudaRng {}

+pub struct ModuleStore {
+    mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
+}
+
 #[derive(Clone)]
 pub struct CudaDevice {
     id: DeviceId,
-    device: Arc<cudarc::driver::CudaDevice>,
+    context: Arc<cudarc::driver::CudaContext>,
+    modules: Arc<std::sync::RwLock<ModuleStore>>,
+    custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
+    stream: Arc<cudarc::driver::CudaStream>,
     pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
     curand: Arc<Mutex<CudaRng>>,
 }
@@ -39,16 +47,51 @@ impl std::fmt::Debug for CudaDevice {
     }
 }

 impl std::ops::Deref for CudaDevice {
-    type Target = Arc<cudarc::driver::CudaDevice>;
+    type Target = Arc<cudarc::driver::CudaStream>;

     fn deref(&self) -> &Self::Target {
-        &self.device
+        &self.stream
     }
 }

+pub struct CudaFunc {
+    func: CudaFunction,
+    stream: Arc<cudarc::driver::CudaStream>,
+}
+
+impl std::ops::Deref for CudaFunc {
+    type Target = CudaFunction;
+
+    fn deref(&self) -> &Self::Target {
+        &self.func
+    }
+}
+
+impl CudaFunc {
+    pub fn into_cuda_function(self) -> CudaFunction {
+        self.func
+    }
+}
+
+#[macro_export]
+macro_rules! builder_arg {
+    ($b:ident, $($arg:expr),*) => {
+        $(
+            let __arg = $arg;
+            $b.arg(&__arg);
+        )*
+    };
+}
+
+impl CudaFunc {
+    pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
+        self.stream.launch_builder(&self.func)
+    }
+}
+
 impl CudaDevice {
-    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
-        self.device.clone()
+    pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
+        self.stream.clone()
     }

     #[cfg(not(target_arch = "wasm32"))]
```
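A note on the new pieces above: cudarc 0.14 replaces the tuple-based `LaunchAsync` API with a per-stream launch builder that takes every argument by reference, so each scalar needs a binding that outlives the launch. That is what `builder_arg!` automates. A call such as `builder_arg!(builder, ncols as i32, ncols_pad as i32)` expands to roughly the following fragment (a sketch of the expansion, not extra API):

```rust
// The shadowed `__arg` bindings all live to the end of the enclosing scope,
// so the references handed to `arg` remain valid until `launch` runs.
let __arg = ncols as i32;
builder.arg(&__arg);
let __arg = ncols_pad as i32;
builder.arg(&__arg);
```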
```diff
@@ -56,7 +99,7 @@ impl CudaDevice {
         &self,
         func_name: &'static str,
         kernel: ug::lang::ssa::Kernel,
-    ) -> Result<CudaFunction> {
+    ) -> Result<CudaFunc> {
         let mut buf = vec![];
         ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
         let cuda_code = String::from_utf8(buf)?;
@@ -65,12 +108,12 @@ impl CudaDevice {
             ..Default::default()
         };
         let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
-        self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
-        let func = match self.device.get_func("ug", func_name) {
-            Some(func) => func,
-            None => crate::bail!("unknown function ug::{func_name}"),
-        };
-        Ok(func)
+        let module = self.context.load_module(ptx).w()?;
+        let func = module.load_function(func_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
     }

     pub fn id(&self) -> DeviceId {
@@ -84,57 +127,84 @@ impl CudaDevice {
             DType::U8 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
-                let params = (&data, v as u8, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as u8;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U8(data)
             }
             DType::U32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
-                let params = (&data, v as u32, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as u32;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::U32(data)
             }
             DType::I64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
-                let params = (&data, v as i64, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as i64;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::I64(data)
             }
             DType::BF16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
-                let params = (&data, bf16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = bf16::from_f64(v);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::BF16(data)
             }
             DType::F16 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
-                let params = (&data, f16::from_f64(v), elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = f16::from_f64(v);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F16(data)
             }
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
-                let params = (&data, v as f32, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                let v = v as f32;
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
                 let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
-                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
-                let params = (&data, v, elem_count);
-                unsafe { func.launch(cfg, params) }.w()?;
+                let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
+                let mut builder = self.stream.launch_builder(&func);
+                builder.arg(&data);
+                builder.arg(&v);
+                builder.arg(&elem_count);
+                unsafe { builder.launch(cfg) }.w()?;
                 CudaStorageSlice::F64(data)
             }
         };
```
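Every dtype arm above undergoes the same mechanical rewrite. Condensed to one case with the names from the diff (a sketch of the before/after, not additional API):

```rust
// cudarc 0.13: arguments packed into a tuple and handed to `launch`.
//     let params = (&data, v as f32, elem_count);
//     unsafe { func.launch(cfg, params) }.w()?;

// cudarc 0.14: arguments registered one by one on a builder tied to a
// stream, then launched. Scalars get a named binding so the reference
// passed to `arg` stays alive until `launch`.
let v = v as f32;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
builder.arg(&v);
builder.arg(&elem_count);
unsafe { builder.launch(cfg) }.w()?;
```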
```diff
@@ -144,38 +214,69 @@ impl CudaDevice {
         })
     }

-    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
-        if !self.has_func(module_name, module_name) {
-            // Leaking the string here is a bit sad but we need a &'static str and this is only
-            // done once per kernel name.
-            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
-            self.load_ptx(ptx.into(), module_name, &[static_module_name])
-                .map_err(|cuda| CudaError::Load {
-                    cuda,
-                    module_name: module_name.to_string(),
-                })
-                .w()?;
+    pub fn get_or_load_custom_func(
+        &self,
+        fn_name: &str,
+        module_name: &str,
+        ptx: &str,
+    ) -> Result<CudaFunc> {
+        let ms = self.custom_modules.read().unwrap();
+        if let Some(mdl) = ms.get(module_name).as_ref() {
+            let func = mdl.load_function(fn_name).w()?;
+            return Ok(CudaFunc {
+                func,
+                stream: self.stream.clone(),
+            });
         }
-        self.get_func(module_name, module_name)
-            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
-            // able to only build the error value if needed.
-            .ok_or(CudaError::MissingKernel {
-                module_name: module_name.to_string(),
-            })
-            .w()
+        drop(ms);
+        let mut ms = self.custom_modules.write().unwrap();
+        let cuda_module = self.context.load_module(ptx.into()).w()?;
+        ms.insert(module_name.to_string(), cuda_module.clone());
+        let func = cuda_module.load_function(fn_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
+    }
+
+    pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
+        let ms = self.modules.read().unwrap();
+        if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
+            let func = mdl.load_function(fn_name).w()?;
+            return Ok(CudaFunc {
+                func,
+                stream: self.stream.clone(),
+            });
+        }
+        drop(ms);
+        let mut ms = self.modules.write().unwrap();
+        let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
+        ms.mdls[mdl.index()] = Some(cuda_module.clone());
+        let func = cuda_module.load_function(fn_name).w()?;
+        Ok(CudaFunc {
+            func,
+            stream: self.stream.clone(),
+        })
     }
 }

 impl CudaDevice {
     pub fn new_with_stream(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+        let stream = context.new_stream().w()?;
+        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+        let module_store = ModuleStore {
+            mdls: [const { None }; kernels::ALL_IDS.len()],
+        };
         Ok(Self {
             id: DeviceId::new(),
-            device,
+            context,
+            stream,
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
+            modules: Arc::new(std::sync::RwLock::new(module_store)),
+            custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
         })
     }
 }
```
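Both loaders above share the same read-then-upgrade locking shape: take the read lock for the fast path, drop it, then take the write lock to populate the cache. A self-contained sketch of the pattern with plain std types (not the candle API):

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Two racing threads may both miss and both load; the second insert simply
// overwrites the first, which is harmless when the cached values are
// immutable once created (as loaded CUDA modules are).
fn get_or_insert(cache: &RwLock<HashMap<String, Arc<str>>>, key: &str) -> Arc<str> {
    if let Some(v) = cache.read().unwrap().get(key) {
        return v.clone();
    }
    // The read guard is dropped here, so taking the write lock cannot deadlock.
    let mut w = cache.write().unwrap();
    let v: Arc<str> = Arc::from(format!("loaded:{key}"));
    w.insert(key.to_string(), v.clone());
    v
}
```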
```diff
@@ -184,14 +285,21 @@ impl BackendDevice for CudaDevice {
     type Storage = CudaStorage;

     fn new(ordinal: usize) -> Result<Self> {
-        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
-        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
-        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        let context = cudarc::driver::CudaContext::new(ordinal).w()?;
+        let stream = context.default_stream();
+        let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
+        let module_store = ModuleStore {
+            mdls: [const { None }; kernels::ALL_IDS.len()],
+        };
         Ok(Self {
             id: DeviceId::new(),
-            device,
+            context,
+            stream,
             blas: Arc::new(blas),
             curand: Arc::new(Mutex::new(CudaRng(curand))),
+            modules: Arc::new(std::sync::RwLock::new(module_store)),
+            custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
         })
     }

@@ -199,13 +307,13 @@ impl BackendDevice for CudaDevice {
         // We do not call set_seed but instead create a new curand object. This ensures that the
         // state will be identical and the same random numbers will be generated.
         let mut curand = self.curand.lock().unwrap();
-        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
+        curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
         Ok(())
     }

     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Cuda {
-            gpu_id: self.device.ordinal(),
+            gpu_id: self.context.ordinal(),
         }
     }

@@ -373,31 +481,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
         let slice = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorageRef::U32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorageRef::I64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorageRef::BF16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorageRef::F16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorageRef::F32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorageRef::F64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -410,31 +518,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.htod_sync_copy(storage).w()?;
+                let data = self.memcpy_stod(storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -447,31 +555,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.htod_copy(storage).w()?;
+                let data = self.memcpy_stod(&storage).w()?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -482,7 +590,7 @@ impl BackendDevice for CudaDevice {
     }

     fn synchronize(&self) -> Result<()> {
-        self.device.synchronize().map_err(crate::Error::wrap)?;
+        self.stream.synchronize().map_err(crate::Error::wrap)?;
         Ok(())
     }
 }
```
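The storage hunks above, and the quantized backend further down, apply one mechanical rename of the cudarc copy helpers. The mapping as a sketch (`stream` is the device's `cuda_stream()`; `host`, `storage`, `src`, `dst` are placeholder names):

```rust
// host slice -> new device buffer            (was: dev.htod_sync_copy(&host))
let d = stream.memcpy_stod(&host)?;
// owned host Vec -> device; 0.14 takes a slice, hence the borrow
//                                            (was: dev.htod_copy(storage))
let d2 = stream.memcpy_stod(&storage)?;
// host slice -> existing device slice        (was: dev.htod_sync_copy_into(src, &mut dst))
stream.memcpy_htod(&src, &mut dst)?;
// device slice -> host Vec                   (was: dev.dtoh_sync_copy(&d.slice(..)))
let v = stream.memcpy_dtov(&d.slice(..))?;
```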
One file diff was suppressed because it is too large.
```diff
@@ -396,7 +396,10 @@ impl UgIOp1 {
         {
             let device = device.as_cuda_device()?;
             let func = device.compile(name, kernel)?;
-            Ok(Self { name, func })
+            Ok(Self {
+                name,
+                func: func.into_cuda_function(),
+            })
         }
         #[cfg(feature = "metal")]
         {
@@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
     #[cfg(feature = "cuda")]
     fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
         use crate::cuda_backend::WrapErr;
-        use cudarc::driver::LaunchAsync;
+        use cudarc::driver::PushKernelArg;

         let elem_count = layout.shape().elem_count();
+        let stream = sto.device.cuda_stream();
         // TODO: support more dtypes.
         let sto = sto.as_cuda_slice::<f32>()?;
         let sto = match layout.contiguous_offsets() {
             None => crate::bail!("input has to be contiguous"),
             Some((o1, o2)) => sto.slice(o1..o2),
         };
-        let params = (&sto,);
         let (g, b) = if elem_count % 32 == 0 {
             (elem_count / 32, 32)
         } else {
@@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
             block_dim: (b as u32, 1, 1),
             shared_mem_bytes: 0,
         };
-        unsafe { self.func.clone().launch(cfg, params) }.w()?;
+        let mut builder = stream.launch_builder(&self.func);
+        builder.arg(&sto);
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }
```
```diff
@@ -45,6 +45,7 @@ pub enum OpCode {
     BinFloat = b'G',
     Append = b'a',
     Appends = b'e',
+    Long1 = 0x8a,
 }

 // Avoid using FromPrimitive so as not to drag another dependency.
@@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
             b'G' => Ok(Self::BinFloat),
             b'a' => Ok(Self::Append),
             b'e' => Ok(Self::Appends),
+            0x8a => Ok(Self::Long1),
             value => Err(value),
         }
     }
@@ -106,6 +108,7 @@ pub enum Object {
         class_name: String,
     },
     Int(i32),
+    Long(i64),
     Float(f64),
     Unicode(String),
     Bool(bool),
@@ -170,6 +173,14 @@ impl Object {
         }
     }

+    pub fn int_or_long(self) -> OResult<i64> {
+        match self {
+            Self::Int(t) => Ok(t as i64),
+            Self::Long(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
     pub fn tuple(self) -> OResult<Vec<Self>> {
         match self {
             Self::Tuple(t) => Ok(t),
@@ -590,6 +601,15 @@ impl Stack {
                 let obj = self.new_obj(class, args)?;
                 self.push(obj)
             }
+            OpCode::Long1 => {
+                let n_bytes = r.read_u8()?;
+                let mut v = 0;
+                // Decode the next n bytes in little endian
+                for i in 0..n_bytes {
+                    v |= (r.read_u8()? as i64) << (i * 8);
+                }
+                self.push(Object::Long(v))
+            }
         }
         Ok(false)
     }
```
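The `Long1` handler above accumulates its length-prefixed payload little-endian. A standalone sketch of the same loop (like the handler, it does not sign-extend, so it is exact for non-negative values that fit in an `i64`):

```rust
fn decode_long1(bytes: &[u8]) -> i64 {
    let mut v = 0i64;
    for (i, b) in bytes.iter().enumerate() {
        // Little-endian: byte i contributes bits [8*i, 8*i + 8).
        v |= (*b as i64) << (i * 8);
    }
    v
}

fn main() {
    // 0x39 | (0x30 << 8) = 0x3039 = 12345
    assert_eq!(decode_long1(&[0x39, 0x30]), 12345);
}
```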
```diff
@@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
     let mut args = args.tuple()?;
     let stride = Vec::<usize>::try_from(args.remove(3))?;
     let size = Vec::<usize>::try_from(args.remove(2))?;
-    let offset = args.remove(1).int()? as usize;
+    let offset = args.remove(1).int_or_long()? as usize;
     let storage = args.remove(0).persistent_load()?;
     let mut storage = storage.tuple()?;
-    let storage_size = storage.remove(4).int()? as usize;
+    let storage_size = storage.remove(4).int_or_long()? as usize;
     let path = storage.remove(2).unicode()?;
     let (_module_name, class_name) = storage.remove(1).class()?;
     let dtype = match class_name.as_str() {
@@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
             crate::bail!("unsupported storage type {other}")
         }
     };
-    let layout = Layout::new(crate::Shape::from(size), stride, offset);
+    let layout = Layout::new(
+        crate::Shape::from(size),
+        stride,
+        offset * dtype.size_in_bytes(),
+    );
     Ok((layout, dtype, path, storage_size))
 }
```
```diff
@@ -1,10 +1,10 @@
 use super::{GgmlDType, QStorage};
 use crate::quantized::k_quants::GgmlType;
 use crate::{backend::BackendDevice, cuda_backend::WrapErr};
-use crate::{CudaDevice, CudaStorage, Result};
+use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
 use half::f16;

-use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
+use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};

 #[derive(Clone, Debug)]
 struct PaddedCudaSlice {
@@ -50,19 +50,20 @@ fn quantize_q8_1(
     ky: usize,
     dev: &CudaDevice,
 ) -> Result<()> {
-    use cudarc::driver::LaunchAsync;
-
     let kx = elem_count;
     let kx_padded = pad(kx, MATRIX_ROW_PADDING);
     let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
-    let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (num_blocks as u32, ky as u32, 1),
         block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
         shared_mem_bytes: 0,
     };
-    let params = (src, dst, kx as i32, kx_padded as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
+    let mut builder = func.builder();
+    builder.arg(src);
+    builder.arg(dst);
+    barg!(builder, kx as i32, kx_padded as i32);
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(())
 }

@@ -72,8 +73,6 @@ fn dequantize_f32(
     elem_count: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
         GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
@@ -99,7 +98,7 @@ fn dequantize_f32(
         GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@@ -110,15 +109,20 @@ fn dequantize_f32(
     };

     if is_k {
-        let params = (&data.inner, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        unsafe { builder.launch(cfg) }.w()?;
     } else {
         let nb32 = match dtype {
             GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
             _ => elem_count / 32,
         };
-        let params = (&data.inner, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        barg!(builder, nb32 as i32);
+        unsafe { builder.launch(cfg) }.w()?;
     }
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }

@@ -129,8 +133,6 @@ fn dequantize_f16(
     elem_count: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let nb = (elem_count + 255) / 256;
     let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
         GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
@@ -156,7 +158,7 @@ fn dequantize_f16(
         GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
@@ -167,15 +169,20 @@ fn dequantize_f16(
     };

     if is_k {
-        let params = (&data.inner, &dst);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        unsafe { builder.launch(cfg) }.w()?;
     } else {
         let nb32 = match dtype {
             GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
             _ => elem_count / 32,
         };
-        let params = (&data.inner, &dst, nb32 as i32);
-        unsafe { func.launch(cfg, params) }.w()?;
+        let mut builder = func.builder();
+        builder.arg(&data.inner);
+        builder.arg(&dst);
+        barg!(builder, nb32 as i32);
+        unsafe { builder.launch(cfg) }.w()?;
     }
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }

@@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
     nrows: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@@ -210,7 +215,7 @@ fn dequantize_mul_mat_vec(
         GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
     let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
     let cfg = cudarc::driver::LaunchConfig {
@@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
         shared_mem_bytes: 0,
     };

-    let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
-    unsafe { func.launch(cfg, params) }.w()?;
+    let mut builder = func.builder();
+    builder.arg(&data.inner);
+    builder.arg(y);
+    builder.arg(&dst);
+    barg!(builder, ncols as i32, nrows as i32);
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }

@@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
     b_size: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < ncols * nrows {
         crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
@@ -266,7 +273,7 @@ fn mul_mat_vec_via_q8_1(
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let kernel_name = format!("{kernel_name}{b_size}");
-    let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
     // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
     let (nblocks, nwarps) = match b_size {
@@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
         shared_mem_bytes: 0,
     };

-    let params = (
-        &data.inner,
-        &y_q8_1,
-        &dst,
+    let mut builder = func.builder();
+    builder.arg(&data.inner);
+    builder.arg(&y_q8_1);
+    builder.arg(&dst);
+    barg!(
+        builder,
         /* ncols_x */ ncols as i32,
         /* nrows_x */ nrows as i32,
         /* nrows_y */ ncols_padded as i32,
-        /* nrows_dst */ nrows as i32,
+        /* nrows_dst */ nrows as i32
     );
-    unsafe { func.launch(cfg, params) }.w()?;
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }

@@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
     y_cols: usize,
     dev: &CudaDevice,
 ) -> Result<CudaStorage> {
-    use cudarc::driver::LaunchAsync;
-
     let data_elems = data.len / dtype.type_size() * dtype.block_size();
     if data_elems < x_rows * x_cols {
         crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
@@ -338,7 +345,7 @@ fn mul_mat_via_q8_1(
         GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
-    let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
+    let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
     let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (
@@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
         shared_mem_bytes: 0,
     };

-    let params = (
-        /* vx */ &data.inner,
-        /* vy */ &y_q8_1,
-        /* dst */ &dst,
+    let mut builder = func.builder();
+    builder.arg(/* vx */ &data.inner);
+    builder.arg(/* vy */ &y_q8_1);
+    builder.arg(/* dst */ &dst);
+    barg!(
+        builder,
         /* ncols_x */ x_cols as i32,
         /* nrows_x */ x_rows as i32,
         /* ncols_y */ y_cols as i32,
         /* nrows_y */ k_padded as i32,
-        /* nrows_dst */ x_rows as i32,
+        /* nrows_dst */ x_rows as i32
     );
-    unsafe { func.launch(cfg, params) }.w()?;
+    unsafe { builder.launch(cfg) }.w()?;
     Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
 }

@@ -416,7 +425,7 @@ impl QCudaStorage {

         let buffer = self
             .device
-            .dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
+            .memcpy_dtov(&self.data.inner.slice(..self.data.len))
             .w()?;
         let mut out = vec![0.0; elem_count];
         let block_len = elem_count / self.dtype.block_size();
@@ -449,7 +458,7 @@ impl QCudaStorage {
         // Run the quantization on cpu.
         let src = match &src.slice {
             crate::cuda_backend::CudaStorageSlice::F32(data) => {
-                self.device.dtoh_sync_copy(data).w()?
+                self.device.memcpy_dtov(data).w()?
             }
             _ => crate::bail!("only f32 can be quantized"),
         };
@@ -462,7 +471,7 @@ impl QCudaStorage {
             data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
         let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
         self.device
-            .htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
             .w()?;
         self.data = PaddedCudaSlice {
             inner,
@@ -599,7 +608,7 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
     let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
     let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
     device
-        .htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
+        .memcpy_htod(data, &mut inner.slice_mut(..data.len()))
        .w()?;
     Ok(QStorage::Cuda(QCudaStorage {
         data: PaddedCudaSlice {
@@ -624,7 +633,7 @@ mod test {
            el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
         let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
@@ -634,7 +643,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_vec_via_q8_1(
@@ -647,7 +656,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         assert_eq!(vs.len(), 1);
         // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
         // Q8 means 1/256 precision.
@@ -662,7 +671,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         assert_eq!(vs.len(), 1);
         assert_eq!(vs[0], 5561851.0);
         Ok(())
@@ -673,7 +682,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -687,7 +696,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();

         /*
         x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -714,7 +723,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let (x_rows, ncols, y_cols) = (4, 16, 2048);
         let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
-        let y = dev.htod_sync_copy(&vs).w()?;
+        let y = dev.memcpy_stod(&vs).w()?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -728,7 +737,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
         Ok(())
     }
 }
```
```diff
@@ -43,43 +43,22 @@ impl From<usize> for Shape {
     }
 }

-impl From<(usize,)> for Shape {
-    fn from(d1: (usize,)) -> Self {
-        Self(vec![d1.0])
+macro_rules! impl_from_tuple {
+    ($tuple:ty, $($index:tt),+) => {
+        impl From<$tuple> for Shape {
+            fn from(d: $tuple) -> Self {
+                Self(vec![$(d.$index,)+])
+            }
+        }
     }
 }

-impl From<(usize, usize)> for Shape {
-    fn from(d12: (usize, usize)) -> Self {
-        Self(vec![d12.0, d12.1])
-    }
-}
-
-impl From<(usize, usize, usize)> for Shape {
-    fn from(d123: (usize, usize, usize)) -> Self {
-        Self(vec![d123.0, d123.1, d123.2])
-    }
-}
-
-impl From<(usize, usize, usize, usize)> for Shape {
-    fn from(d1234: (usize, usize, usize, usize)) -> Self {
-        Self(vec![d1234.0, d1234.1, d1234.2, d1234.3])
-    }
-}
-
-impl From<(usize, usize, usize, usize, usize)> for Shape {
-    fn from(d12345: (usize, usize, usize, usize, usize)) -> Self {
-        Self(vec![d12345.0, d12345.1, d12345.2, d12345.3, d12345.4])
-    }
-}
-
-impl From<(usize, usize, usize, usize, usize, usize)> for Shape {
-    fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self {
-        Self(vec![
-            d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5,
-        ])
-    }
-}
+impl_from_tuple!((usize,), 0);
+impl_from_tuple!((usize, usize), 0, 1);
+impl_from_tuple!((usize, usize, usize), 0, 1, 2);
+impl_from_tuple!((usize, usize, usize, usize), 0, 1, 2, 3);
+impl_from_tuple!((usize, usize, usize, usize, usize), 0, 1, 2, 3, 4);
+impl_from_tuple!((usize, usize, usize, usize, usize, usize), 0, 1, 2, 3, 4, 5);

 impl From<Vec<usize>> for Shape {
     fn from(dims: Vec<usize>) -> Self {
```
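Each invocation expands to the impl that the handwritten versions used to spell out; for example, `impl_from_tuple!((usize, usize), 0, 1)` generates:

```rust
impl From<(usize, usize)> for Shape {
    fn from(d: (usize, usize)) -> Self {
        Self(vec![d.0, d.1])
    }
}
```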
```diff
@@ -636,4 +615,20 @@ mod tests {
         let shape = Shape::from((299, 792, 458));
         assert_eq!(shape.stride_contiguous(), [458 * 792, 458, 1]);
     }
+
+    #[test]
+    fn test_from_tuple() {
+        let shape = Shape::from((2,));
+        assert_eq!(shape.dims(), &[2]);
+        let shape = Shape::from((2, 3));
+        assert_eq!(shape.dims(), &[2, 3]);
+        let shape = Shape::from((2, 3, 4));
+        assert_eq!(shape.dims(), &[2, 3, 4]);
+        let shape = Shape::from((2, 3, 4, 5));
+        assert_eq!(shape.dims(), &[2, 3, 4, 5]);
+        let shape = Shape::from((2, 3, 4, 5, 6));
+        assert_eq!(shape.dims(), &[2, 3, 4, 5, 6]);
+        let shape = Shape::from((2, 3, 4, 5, 6, 7));
+        assert_eq!(shape.dims(), &[2, 3, 4, 5, 6, 7]);
+    }
 }
```
```diff
@@ -56,7 +56,7 @@ impl ArgSort {
 mod cuda {
     use super::*;
     use crate::cuda_backend::cudarc::driver::{
-        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
+        CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
     };
     use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
     use crate::{CudaDevice, WithDType};
@@ -69,6 +69,8 @@ mod cuda {
         layout: &crate::Layout,
         _wrap: W,
     ) -> Result<S> {
+        use cudarc::driver::PushKernelArg;
+
         let slice = match layout.contiguous_offsets() {
             None => crate::bail!("input has to be contiguous"),
             Some((o1, o2)) => src.slice(o1..o2),
@@ -76,20 +78,24 @@ mod cuda {
         let elem_count = layout.shape().elem_count();
         let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
         let func = if self.asc {
-            dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
+            dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
         } else {
-            dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
+            dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
         };
         let ncols = self.last_dim;
         let nrows = elem_count / ncols;
         let ncols_pad = next_power_of_2(ncols);
-        let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
         let cfg = LaunchConfig {
             grid_dim: (1, nrows as u32, 1),
             block_dim: (ncols_pad as u32, 1, 1),
             shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
         };
-        unsafe { func.launch(cfg, params) }.w()?;
+        let stream = dev.cuda_stream();
+        let mut builder = stream.launch_builder(&func);
+        let ncols = ncols as i32;
+        let ncols_pad = ncols_pad as i32;
+        builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
+        unsafe { builder.launch(cfg) }.w()?;
         Ok(S::U32(dst))
     }
 }
```
```diff
@@ -2580,6 +2580,28 @@ impl Tensor {
     pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
         rhs.broadcast_mul(&self.log()?)?.exp()
     }
+
+    /// Returns a new tensor with the order of elements reversed along the specified dimensions.
+    /// This function makes a copy of the tensor’s data.
+    ///
+    /// ```rust
+    /// # use candle_core::{Tensor, Device};
+    /// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
+    /// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    /// let t_flipped = t.flip(&[0])?;
+    /// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
+        let mut result = self.clone();
+        for &dim in dims.iter() {
+            let size = result.dim(dim)?;
+            let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
+            let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
+            result = result.index_select(&indices_tensor, dim)?;
+        }
+        Ok(result)
+    }
 }

 macro_rules! bin_trait {
```
```diff
@@ -24,6 +24,15 @@ macro_rules! test_device {
     };
 }

+pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
+    assert_eq!(t1.shape(), t2.shape());
+    // Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
+    let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
+    let all_equal = eq_tensor.sum_all()?;
+    assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
+    Ok(())
+}
+
 pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
     let b = 10f32.powi(digits);
     let t = t.to_vec0::<f32>()?;
```
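A usage sketch for the new helper, inside a test that returns a `Result` (tensor construction is illustrative):

```rust
use candle_core::{test_utils, Device, Result, Tensor};

fn check() -> Result<()> {
    let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    let b = a.copy()?;
    // Panics if the shapes differ or any element pair compares unequal.
    test_utils::assert_tensor_eq(&a, &b)?;
    Ok(())
}
```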
```diff
@@ -1,6 +1,6 @@
 #![allow(clippy::approx_constant)]
 use anyhow::{Context, Result};
-use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
+use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};

 fn simple_grad(device: &Device) -> Result<()> {
     let x = Var::new(&[3f32, 1., 4.], device)?;
@@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
     Ok(())
 }

+#[test]
+fn test_flip_backprop() -> Result<()> {
+    let device = &Device::Cpu;
+
+    // Create a tensor (leaf node) that requires gradients
+    let x = Var::ones((2, 2), DType::F64, device)?;
+    let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
+
+    let y = x.matmul(&weights)?;
+    let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
+    candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
+
+    let z = y.flip(&[1])?;
+    let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
+    candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
+
+    let loss = z.sum_all()?;
+
+    let grad_store = loss.backward()?;
+    let grad_x = grad_store.get_id(x.id()).unwrap();
+
+    let flipped_weights = weights.flip(&[1])?;
+    let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
+    // dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
+    let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
+    candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
+
+    Ok(())
+}
+
 test_device!(
     simple_grad,
     simple_grad_cpu,
```
```diff
@@ -880,10 +880,10 @@ fn get_random_tensors(
     let mut rng = StdRng::seed_from_u64(314159265358979);

     let lhs = (0..m * k)
-        .map(|_| rng.gen::<f32>() - 0.5)
+        .map(|_| rng.random::<f32>() - 0.5)
         .collect::<Vec<_>>();
     let rhs = (0..n * k)
-        .map(|_| rng.gen::<f32>() - 0.5)
+        .map(|_| rng.random::<f32>() - 0.5)
         .collect::<Vec<_>>();

     let lhs = Tensor::from_vec(lhs, (m, k), device)?;
```
```diff
@@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn test_flip_1d() -> Result<()> {
+    // 1D: [0, 1, 2, 3, 4]
+    let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
+    let flipped = t.flip(&[0])?;
+    // Expected: [4, 3, 2, 1, 0]
+    let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
+    candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
+    Ok(())
+}
+
+#[test]
+fn test_flip_2d() -> Result<()> {
+    // 2D:
+    // [[0, 1, 2],
+    //  [3, 4, 5]]
+    let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
+    let flipped = t.flip(&[0, 1])?;
+    // Expected:
+    // [[5, 4, 3],
+    //  [2, 1, 0]]
+    let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
+    candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
+    Ok(())
+}
+
+#[test]
+fn test_flip_3d_channels() -> Result<()> {
+    // 3D:
+    // [[[0,1,2],
+    //   [3,4,5]],
+    //
+    //  [[6,7,8],
+    //   [9,10,11]]]
+    let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
+    let flipped = t.flip(&[2])?;
+    // Expected:
+    // [[[2,1,0],
+    //   [5,4,3]],
+    //
+    //  [[8,7,6],
+    //   [11,10,9]]]
+    let expected = Tensor::from_vec(
+        vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
+        (2, 2, 3),
+        &Device::Cpu,
+    )?;
+    candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
+    Ok(())
+}
```
@ -60,8 +60,8 @@ pub struct DatasetRandomIter<'a> {

impl<'a> DatasetRandomIter<'a> {
    pub fn new(ds: &'a Dataset, valid: bool, seq_len: usize, device: Device) -> Self {
        use rand::rng;
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let all_tokens = if valid {
            &ds.valid_tokens
@ -69,13 +69,13 @@ impl<'a> DatasetRandomIter<'a> {
            &ds.train_tokens
        };
        let mut tokens = all_tokens.iter().collect::<Vec<_>>();
        tokens.shuffle(&mut thread_rng());
        tokens.shuffle(&mut rng());
        let current_tokens = tokens.pop().unwrap();
        let seq_len_in_bytes = seq_len * 2;
        let mut indexes_in_bytes = (0..current_tokens.len() - seq_len_in_bytes)
            .step_by(seq_len_in_bytes)
            .collect::<Vec<_>>();
        indexes_in_bytes.shuffle(&mut thread_rng());
        indexes_in_bytes.shuffle(&mut rng());
        Self {
            all_tokens,
            tokens,
@ -92,21 +92,21 @@ impl Iterator for DatasetRandomIter<'_> {

    fn next(&mut self) -> Option<Self::Item> {
        use byteorder::{LittleEndian, ReadBytesExt};
        use rand::rng;
        use rand::seq::SliceRandom;
        use rand::thread_rng;

        let seq_len = self.seq_len;
        if self.indexes_in_bytes.is_empty() {
            if self.tokens.is_empty() {
                self.tokens = self.all_tokens.iter().collect();
                self.tokens.shuffle(&mut thread_rng());
                self.tokens.shuffle(&mut rng());
            }
            self.current_tokens = self.tokens.pop().unwrap();
            let seq_len_in_bytes = self.seq_len * 2;
            self.indexes_in_bytes = (0..self.current_tokens.len() - seq_len_in_bytes)
                .step_by(seq_len_in_bytes)
                .collect::<Vec<_>>();
            self.indexes_in_bytes.shuffle(&mut thread_rng());
            self.indexes_in_bytes.shuffle(&mut rng());
        }
        let start_idx = self.indexes_in_bytes.pop().unwrap();
        let bytes = &self.current_tokens[start_idx..start_idx + 2 * (seq_len + 1)];
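The `thread_rng`/`gen` edits here and in several hunks below are the rand 0.8 → 0.9 migration rather than behavioral changes. A minimal sketch of the renamed entry points, assuming the rand 0.9 API (`thread_rng()` → `rng()`, `Rng::gen` → `Rng::random`, `rand::distributions` → `rand::distr`):

```rust
use rand::distr::{weighted::WeightedIndex, Distribution};
use rand::seq::SliceRandom;
use rand::{rng, Rng};

fn rand_09_renames() {
    let mut r = rng(); // rand 0.8: rand::thread_rng()
    let _x: f32 = r.random::<f32>() - 0.5; // rand 0.8: r.gen::<f32>()
    let mut idxs = vec![0usize, 1, 2, 3];
    idxs.shuffle(&mut r); // unchanged call site, only the generator constructor moved
    // rand 0.8: rand::distributions::WeightedIndex
    let w = WeightedIndex::new([0.1f32, 0.9]).unwrap();
    let _sample = w.sample(&mut r);
}
```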
@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
        if let parquet::record::Field::Group(subrow) = field {
            for (_name, field) in subrow.get_column_iter() {
                if let parquet::record::Field::Bytes(value) = field {
                    // image-rs crate convention is to load in (width, height, channels) order
                    // See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
                    let image = image::load_from_memory(value.data()).unwrap();
                    buffer_images.extend(image.to_rgb8().as_raw());
                }
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
            }
        }
    }
    let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
        .to_dtype(DType::U8)?
    // Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
    let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
        .to_dtype(DType::F32)?
        .permute((0, 3, 2, 1))?
        / 255.)?;
    let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
    Ok((images, labels))
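The layout change above is the whole fix: build the tensor in the decoder's native order, then permute into the NCHW layout candle convolutions expect. A standalone sketch of the same reorder, with an illustrative helper name (`to_nchw` is not part of the diff):

```rust
use candle_core::{DType, Device, Result, Tensor};

// Flat u8 RGB buffer in (N, W, H, C) order -> NCHW f32 scaled to [0, 1].
fn to_nchw(buffer: Vec<u8>, n: usize) -> Result<Tensor> {
    let t = Tensor::from_vec(buffer, (n, 32, 32, 3), &Device::Cpu)?;
    let t = t.to_dtype(DType::F32)?;
    // (0, 3, 2, 1): (N, W, H, C) -> (N, C, H, W), matching the permute above.
    t.permute((0, 3, 2, 1))? / 255.
}
```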
13
candle-examples/examples/chatglm/README.md
Normal file
@ -0,0 +1,13 @@
# candle-chatglm

Uses `THUDM/chatglm3-6b` to generate Chinese text. It will usually not generate English text.

## Text Generation

```bash
cargo run --example chatglm --release -- --prompt "部署门槛较低等众多优秀特 "

> 部署门槛较低等众多优秀特 点,使得其成为了一款备受欢迎的AI助手。
>
> 作为一款人工智能助手,ChatGLM3-6B
```
42
candle-examples/examples/chinese_clip/README.md
Normal file
@ -0,0 +1,42 @@
# candle-chinese-clip

Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
pairs of images with related texts. This variant is trained on Chinese text instead of English.

## Running on CPU

```bash
$ cargo run --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"

> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z  INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z  INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z  INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z  INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z  INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z  INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z  INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```

## Running on Metal

```bash
$ cargo run --features metal --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"

> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
>
> 2025-03-25T19:22:01.325177Z  INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325179Z  INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325181Z  INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
> 2025-03-25T19:22:01.325183Z  INFO chinese_clip:
>
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
>
> 2025-03-25T19:22:01.325184Z  INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
> 2025-03-25T19:22:01.325186Z  INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
> 2025-03-25T19:22:01.325187Z  INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
```
17
candle-examples/examples/convmixer/README.md
Normal file
@ -0,0 +1,17 @@
# candle-convmixer

A lightweight CNN architecture that processes image patches similarly to a vision transformer, with separate spatial and channel convolutions.

ConvMixer from [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792) and [ConvMixer](https://github.com/locuslab/convmixer).

## Running an example

```bash
$ cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

> mountain bike, all-terrain bike, off-roader: 61.75%
> unicycle, monocycle : 5.73%
> moped               : 3.66%
> bicycle-built-for-two, tandem bicycle, tandem: 3.51%
> crash helmet        : 0.85%
```
221
candle-examples/examples/csm/main.rs
Normal file
@ -0,0 +1,221 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::csm::{Config, Model};

use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "1b")]
    Csm1b,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long, default_value = "[0]Hey how are you doing?")]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.7)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    /// The model size to use.
    #[arg(long, default_value = "1b")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    #[arg(long)]
    tokenizer: Option<String>,

    #[arg(long)]
    config: Option<String>,

    #[arg(long)]
    weights: Option<String>,

    /// The mimi model weight file, in safetensor format.
    #[arg(long)]
    mimi_weights: Option<String>,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => {
            let name = match args.which {
                Which::Csm1b => "sesame/csm-1b",
            };
            name.to_string()
        }
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let filenames = match args.weights {
        Some(files) => files
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => vec![repo.get("model.safetensors")?],
    };
    let tokenizer_filename = match args.tokenizer {
        Some(file) => std::path::PathBuf::from(file),
        None => api
            .model("meta-llama/Llama-3.2-1B".to_string())
            .get("tokenizer.json")?,
    };
    let mimi_filename = match args.mimi_weights {
        Some(model) => std::path::PathBuf::from(model),
        None => Api::new()?
            .model("kyutai/mimi".to_string())
            .get("model.safetensors")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config: Config = match args.config {
        Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
        None => {
            let config_file = repo.get("config.json")?;
            serde_json::from_slice(&std::fs::read(config_file)?)?
        }
    };
    let device = candle_examples::device(args.cpu)?;
    let (mut model, device) = {
        let dtype = DType::F32;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Model::new(&config, vb)?;
        (model, device)
    };
    let mut mimi_model = {
        use candle_transformers::models::mimi;
        let vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
        let config = mimi::Config::v0_1(Some(32));
        mimi::Model::new(config, vb)?
    };
    let cb = config.audio_num_codebooks;

    println!("loaded the model in {:?}", start.elapsed());
    if args.prompt.ends_with(".safetensors") {
        let prompt = candle::safetensors::load(args.prompt, &device)?;
        let mut tokens = prompt
            .get("tokens")
            .expect("no tokens in prompt")
            .to_dtype(DType::U32)?;
        let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
        println!("tokens:\n{tokens:?}");
        println!("mask:\n{mask:?}");
        let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
        let mut const_mask = vec![1u8; cb];
        const_mask.push(0);
        let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
        let mut pos = 0;
        let mut all_tokens = vec![];
        for i in 0.. {
            let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
            pos += tokens.dim(1)?;
            frame.push(0);
            if frame.iter().all(|&x| x == 0) {
                break;
            }
            println!("frame {i} {pos}:\n{frame:?}");
            tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
            all_tokens.push(tokens.clone());
            mask = const_mask.clone();
        }
        let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
        println!("all_tokens:\n{all_tokens:?}");
        let pcm = mimi_model.decode(&all_tokens)?;
        let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
        let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
        let pcm = pcm.to_vec1::<f32>()?;
        let mut output = std::fs::File::create("out.wav")?;
        candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
    } else {
        let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
        println!("{prompt:?}");
    }

    Ok(())
}
17
candle-examples/examples/custom-ops/README.md
Normal file
@ -0,0 +1,17 @@
# candle-custom-ops

This example illustrates how to implement forward and backward passes for custom operations on the CPU and GPU.
The custom op in this example implements RMS normalization for the CPU and CUDA.

## Running an example

```bash
$ cargo run --example custom-ops

> [[ 0.,  1.,  2.,  3.,  4.,  5.,  6.],
>  [ 7.,  8.,  9., 10., 11., 12., 13.]]
> Tensor[[2, 7], f32]
> [[0.0000, 0.2773, 0.5547, 0.8320, 1.1094, 1.3867, 1.6641],
>  [0.6864, 0.7845, 0.8825, 0.9806, 1.0786, 1.1767, 1.2748]]
> Tensor[[2, 7], f32]
```
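To make the README's claim concrete, here is a minimal CPU-only sketch of candle's `CustomOp1` trait — an illustrative element-wise negation rather than the RMS-norm op this example actually ships; only the f32 path is handled and a contiguous layout is assumed:

```rust
use candle_core::{bail, CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

// Hypothetical op for illustration; not part of the example's source.
struct Neg;

impl CustomOp1 for Neg {
    fn name(&self) -> &'static str {
        "neg"
    }

    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let data = match storage {
            CpuStorage::F32(d) => d,
            _ => bail!("neg: only f32 is supported in this sketch"),
        };
        // A real op should honor the layout's offset/strides; we assume contiguous data.
        let out: Vec<f32> = data.iter().map(|v| -v).collect();
        Ok((CpuStorage::F32(out), layout.shape().clone()))
    }
}

fn apply_neg(t: &Tensor) -> Result<Tensor> {
    t.apply_op1(Neg)
}
```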
@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
        layout: &Layout,
    ) -> Result<(candle::CudaStorage, Shape)> {
        use candle::backend::BackendStorage;
        use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
        use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
        use candle::cuda_backend::WrapErr;
        let (d1, d2) = layout.shape().dims2()?;
        let d1 = d1 as u32;
@ -69,14 +69,18 @@ impl CustomOp1 for LayerNorm {
        };
        let elem_count = layout.shape().elem_count();
        let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
        let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
        let params = (&dst, &slice, self.eps, d1, d2);
        let func =
            dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
        let cfg = LaunchConfig {
            grid_dim: (d1, 1, 1),
            block_dim: (d2, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe { func.launch(cfg, params) }.w()?;
        let mut builder = func.builder();
        builder.arg(&dst);
        builder.arg(&slice);
        candle::builder_arg!(builder, self.eps, d1, d2);
        unsafe { builder.launch(cfg) }.w()?;

        let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
        Ok((dst, layout.shape().clone()))
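For contrast, the old and new launch paths side by side — a fragment lifted from the hunk above rather than a standalone program; `func`, `dst`, `slice`, `cfg`, `d1`, `d2`, and `self.eps` are as defined in the surrounding example:

```rust
// cudarc 0.13: kernel arguments packed into a tuple, launched in one call.
// let params = (&dst, &slice, self.eps, d1, d2);
// unsafe { func.launch(cfg, params) }.w()?;

// cudarc 0.14: arguments pushed one by one through a launch builder.
let mut builder = func.builder();
builder.arg(&dst);
builder.arg(&slice);
candle::builder_arg!(builder, self.eps, d1, d2); // scalar args go through the helper macro
unsafe { builder.launch(cfg) }.w()?;
```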
33
candle-examples/examples/deepseekv2/README.md
Normal file
@ -0,0 +1,33 @@
# DeepSeek V2

DeepSeek V2 is an MoE model featuring MLA (Multi-head Latent Attention). There is a lite (16B) and a full (236B) model.

- Context length of **32k tokens** (Lite model), **128k tokens** (full model)
- 64 routed experts (Lite model), 160 routed experts (full model)

## Running the example

```bash
$ cargo run --example deepseekv2 --release --features metal -- --prompt "Recursive fibonacci code in Rust:" --which lite --sample-len 150

fn fibonacci(n: u32) -> u32 {
    if n <= 1 {
        return n;
    } else {
        return fibonacci(n - 1) + fibonacci(n - 2);
    }
}

## Fibonacci code in Python:

def fibonacci(n):
    if n <= 1:
        return n
    else:
        return fibonacci(n-1) + fibonacci(n-2)

## Fibonacci code in JavaScript:

function fibonacci(n) {
    if (n <= 1
```
282
candle-examples/examples/deepseekv2/main.rs
Normal file
@ -0,0 +1,282 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::deepseek2::{DeepSeekV2, DeepSeekV2Config};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
    model: DeepSeekV2,
    device: Device,
    tokenizer: TokenOutputStream,
    logits_processor: LogitsProcessor,
    repeat_penalty: f32,
    repeat_last_n: usize,
}

impl TextGeneration {
    #[allow(clippy::too_many_arguments)]
    fn new(
        model: DeepSeekV2,
        tokenizer: Tokenizer,
        seed: u64,
        temp: Option<f64>,
        top_p: Option<f64>,
        top_k: Option<usize>,
        repeat_penalty: f32,
        repeat_last_n: usize,
        device: &Device,
    ) -> Self {
        let logits_processor = {
            let temperature = temp.unwrap_or(0.);
            let sampling = if temperature <= 0. {
                Sampling::ArgMax
            } else {
                match (top_k, top_p) {
                    (None, None) => Sampling::All { temperature },
                    (Some(k), None) => Sampling::TopK { k, temperature },
                    (None, Some(p)) => Sampling::TopP { p, temperature },
                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
                }
            };
            LogitsProcessor::from_sampling(seed, sampling)
        };

        Self {
            model,
            tokenizer: TokenOutputStream::new(tokenizer),
            logits_processor,
            repeat_penalty,
            repeat_last_n,
            device: device.clone(),
        }
    }

    fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
        use std::io::Write;
        self.tokenizer.clear();
        let mut tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(E::msg)?
            .get_ids()
            .to_vec();
        for &t in tokens.iter() {
            if let Some(t) = self.tokenizer.next_token(t)? {
                print!("{t}")
            }
        }
        std::io::stdout().flush()?;

        let mut generated_tokens = 0usize;
        let eos_token = match self.tokenizer.get_token("<|end▁of▁sentence|>") {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <|end▁of▁sentence|> token"),
        };
        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
            let start_pos = tokens.len().saturating_sub(context_size);
            let ctxt = &tokens[start_pos..];
            let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
            let logits = self.model.forward(&input, start_pos)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            let logits = if self.repeat_penalty == 1. {
                logits
            } else {
                let start_at = tokens.len().saturating_sub(self.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    self.repeat_penalty,
                    &tokens[start_at..],
                )?
            };

            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
        }
        let dt = start_gen.elapsed();
        if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        println!(
            "\n{generated_tokens} tokens generated ({:.2} token/s)",
            generated_tokens as f64 / dt.as_secs_f64(),
        );
        Ok(())
    }
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    #[value(name = "lite")]
    Lite,
    #[value(name = "lite-chat")]
    LiteChat,
    #[value(name = "coder-lite-chat")]
    CoderLiteChat,
    #[value(name = "v2")]
    V2,
    #[value(name = "v2-chat")]
    V2Chat,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    #[arg(long)]
    use_flash_attn: bool,

    #[arg(long)]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 10000)]
    sample_len: usize,

    /// The model size to use.
    #[arg(long, default_value = "lite")]
    which: Which,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long, default_value = "main")]
    revision: String,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature.unwrap_or(0.),
        args.repeat_penalty,
        args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id,
        None => match args.which {
            Which::CoderLiteChat => "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct".to_string(),
            Which::LiteChat => "deepseek-ai/DeepSeek-V2-Lite-Chat".to_string(),
            Which::Lite => "deepseek-ai/DeepSeek-V2-Lite".to_string(),
            Which::V2 => "deepseek-ai/DeepSeek-V2".to_string(),
            Which::V2Chat => "deepseek-ai/DeepSeek-V2-Chat".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        args.revision,
    ));
    let tokenizer_filename = repo.get("tokenizer.json")?;
    let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let config: DeepSeekV2Config = {
        let config_file = repo.get("config.json")?;
        serde_json::from_slice(&std::fs::read(config_file)?)?
    };
    let device = candle_examples::device(args.cpu)?;
    let (model, device) = {
        let dtype = if device.is_cpu() {
            DType::F16
        } else {
            DType::BF16
        };
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = DeepSeekV2::new(&config, vb)?;
        (model, device)
    };

    println!("loaded the model in {:?}", start.elapsed());

    let mut pipeline = TextGeneration::new(
        model,
        tokenizer,
        args.seed,
        args.temperature,
        args.top_p,
        args.top_k,
        args.repeat_penalty,
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;
    Ok(())
}
15
candle-examples/examples/efficientnet/README.md
Normal file
@ -0,0 +1,15 @@
# candle-efficientnet

Demonstrates a Candle implementation of EfficientNet for image classification based on ImageNet classes.

## Running an example

```bash
$ cargo run --example efficientnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which b1

> bicycle-built-for-two, tandem bicycle, tandem: 45.85%
> mountain bike, all-terrain bike, off-roader: 30.45%
> crash helmet        : 2.58%
> unicycle, monocycle : 2.21%
> tricycle, trike, velocipede: 1.53%
```
@ -1,3 +1,10 @@
# candle-falcon

Falcon is a general large language model.

## Running an example

Make sure to include the `--use-f32` flag when running on CPU, as there is no BFloat16 implementation yet.
```
cargo run --example falcon --release -- --prompt "Flying monkeys are" --use-f32
```
@ -9,6 +9,7 @@ use clap::Parser;

use candle_transformers::models::gemma::{Config as Config1, Model as Model1};
use candle_transformers::models::gemma2::{Config as Config2, Model as Model2};
use candle_transformers::models::gemma3::{Config as Config3, Model as Model3};

use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
@ -47,29 +48,16 @@ enum Which {
    BaseV2_9B,
    #[value(name = "2-9b-it")]
    InstructV2_9B,
}

impl Which {
    fn is_v1(&self) -> bool {
        match self {
            Self::Base2B
            | Self::Base7B
            | Self::Instruct2B
            | Self::Instruct7B
            | Self::InstructV1_1_2B
            | Self::InstructV1_1_7B
            | Self::CodeBase2B
            | Self::CodeBase7B
            | Self::CodeInstruct2B
            | Self::CodeInstruct7B => true,
            Self::BaseV2_2B | Self::InstructV2_2B | Self::BaseV2_9B | Self::InstructV2_9B => false,
        }
    }
    #[value(name = "3-1b")]
    BaseV3_1B,
    #[value(name = "3-1b-it")]
    InstructV3_1B,
}

enum Model {
    V1(Model1),
    V2(Model2),
    V3(Model3),
}

impl Model {
@ -77,6 +65,7 @@ impl Model {
        match self {
            Self::V1(m) => m.forward(input_ids, pos),
            Self::V2(m) => m.forward(input_ids, pos),
            Self::V3(m) => m.forward(input_ids, pos),
        }
    }
}
@ -284,6 +273,8 @@ fn main() -> Result<()> {
            Which::InstructV2_2B => "google/gemma-2-2b-it".to_string(),
            Which::BaseV2_9B => "google/gemma-2-9b".to_string(),
            Which::InstructV2_9B => "google/gemma-2-9b-it".to_string(),
            Which::BaseV3_1B => "google/gemma-3-1b-pt".to_string(),
            Which::InstructV3_1B => "google/gemma-3-1b-it".to_string(),
        },
    };
    let repo = api.repo(Repo::with_revision(
@ -304,7 +295,10 @@ fn main() -> Result<()> {
            .split(',')
            .map(std::path::PathBuf::from)
            .collect::<Vec<_>>(),
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
        None => match args.which {
            Which::BaseV3_1B | Which::InstructV3_1B => vec![repo.get("model.safetensors")?],
            _ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
        },
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@ -317,14 +311,31 @@ fn main() -> Result<()> {
        DType::F32
    };
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
    let model = if args.which.is_v1() {
        let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
        let model = Model1::new(args.use_flash_attn, &config, vb)?;
        Model::V1(model)
    } else {
        let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
        let model = Model2::new(args.use_flash_attn, &config, vb)?;
        Model::V2(model)
    let model = match args.which {
        Which::Base2B
        | Which::Base7B
        | Which::Instruct2B
        | Which::Instruct7B
        | Which::InstructV1_1_2B
        | Which::InstructV1_1_7B
        | Which::CodeBase2B
        | Which::CodeBase7B
        | Which::CodeInstruct2B
        | Which::CodeInstruct7B => {
            let config: Config1 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model1::new(args.use_flash_attn, &config, vb)?;
            Model::V1(model)
        }
        Which::BaseV2_2B | Which::InstructV2_2B | Which::BaseV2_9B | Which::InstructV2_9B => {
            let config: Config2 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model2::new(args.use_flash_attn, &config, vb)?;
            Model::V2(model)
        }
        Which::BaseV3_1B | Which::InstructV3_1B => {
            let config: Config3 = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
            let model = Model3::new(args.use_flash_attn, &config, vb)?;
            Model::V3(model)
        }
    };

    println!("loaded the model in {:?}", start.elapsed());
@ -12,7 +12,7 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode

** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
#+end_src

** Output Example
11
candle-examples/examples/llama/README.md
Normal file
@ -0,0 +1,11 @@
# candle-llama

Candle implementations of various Llama-based architectures.

## Running an example

```bash
$ cargo run --example llama -- --prompt "Machine learning is " --which v32-3b-instruct

> Machine learning is the part of computer science which deals with the development of algorithms and
```
@ -12,6 +12,6 @@ would only work for inference.
## Running the example

```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```
@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
mountain. I cannot stay far from you any longer.</s>
```

### Changing model and language pairs

```bash
$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh

你好,你好吗?
```

## Generating the tokenizer.json files

You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be install and use the `convert_slow_tokenizer.py` script from this
directory.

```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```
The tokenizer for each `marian-mt` model was trained independently,
meaning each new model needs unique tokenizer encoders and decoders.
You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
the `tokenizer.json` config files from the hf-hub repos.
The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
to be installed, and has only been tested for `python 3.12.7`.

File diff suppressed because it is too large
@ -20,6 +20,22 @@ enum Which {
    Big,
}

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum LanguagePair {
    #[value(name = "fr-en")]
    FrEn,
    #[value(name = "en-zh")]
    EnZh,
    #[value(name = "en-hi")]
    EnHi,
    #[value(name = "en-es")]
    EnEs,
    #[value(name = "en-fr")]
    EnFr,
    #[value(name = "en-ru")]
    EnRu,
}

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
@ -36,6 +52,10 @@ struct Args {
    #[arg(long, default_value = "big")]
    which: Which,

    // Choose which language pair to use
    #[arg(long, default_value = "fr-en")]
    language_pair: LanguagePair,

    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,
@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
    use hf_hub::api::sync::Api;
    let args = Args::parse();

    let config = match args.which {
        Which::Base => marian::Config::opus_mt_fr_en(),
        Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
    let config = match (args.which, args.language_pair) {
        (Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
        (Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
        (Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
        (Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
        (Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
        (Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
        (Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
        (Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
    };
    let tokenizer_default_repo = match args.language_pair {
        LanguagePair::FrEn => "lmz/candle-marian",
        LanguagePair::EnZh
        | LanguagePair::EnHi
        | LanguagePair::EnEs
        | LanguagePair::EnFr
        | LanguagePair::EnRu => "KeighBee/candle-marian",
    };
    let tokenizer = {
        let tokenizer = match args.tokenizer {
            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
            None => {
                let name = match args.which {
                    Which::Base => "tokenizer-marian-base-fr.json",
                    Which::Big => "tokenizer-marian-fr.json",
                let filename = match (args.which, args.language_pair) {
                    (Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
                    (Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
                    (Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
                    (Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
                    (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
                    (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
                    (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
                    (Which::Big, lp) => {
                        anyhow::bail!("big is not supported for language pair {lp:?}")
                    }
                };
                Api::new()?
                    .model("lmz/candle-marian".to_string())
                    .get(name)?
                    .model(tokenizer_default_repo.to_string())
                    .get(filename)?
            }
        };
        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
    let tokenizer = match args.tokenizer_dec {
        Some(tokenizer) => std::path::PathBuf::from(tokenizer),
        None => {
            let name = match args.which {
                Which::Base => "tokenizer-marian-base-en.json",
                Which::Big => "tokenizer-marian-en.json",
            let filename = match (args.which, args.language_pair) {
                (Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
                (Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
                (Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
                (Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
                (Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
                (Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
                (Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
                (Which::Big, lp) => {
                    anyhow::bail!("big is not supported for language pair {lp:?}")
                }
            };
            Api::new()?
                .model("lmz/candle-marian".to_string())
                .get(name)?
                .model(tokenizer_default_repo.to_string())
                .get(filename)?
        }
    };
    Tokenizer::from_file(&tokenizer).map_err(E::msg)?
@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
            None => match args.which {
                Which::Base => Api::new()?
                    .repo(hf_hub::Repo::with_revision(
            None => {
                let api = Api::new()?;
                let api = match (args.which, args.language_pair) {
                    (Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-fr-en".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    ))
                    .get("model.safetensors")?,
                Which::Big => Api::new()?
                    .model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    .get("model.safetensors")?,
            },
                    )),
                    (Which::Big, LanguagePair::FrEn) => {
                        api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
                    }
                    (Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-en-zh".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/13".to_string(),
                    )),
                    (Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-en-hi".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/3".to_string(),
                    )),
                    (Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-en-es".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/4".to_string(),
                    )),
                    (Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-en-fr".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/9".to_string(),
                    )),
                    (Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
                        "Helsinki-NLP/opus-mt-en-ru".to_string(),
                        hf_hub::RepoType::Model,
                        "refs/pr/7".to_string(),
                    )),
                    (Which::Big, lp) => {
                        anyhow::bail!("big is not supported for language pair {lp:?}")
                    }
                };
                api.get("model.safetensors")?
            }
        };
        unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
    };
@ -0,0 +1,53 @@
from pathlib import Path
import warnings

from transformers import AutoTokenizer
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf

class MarianConverter(SpmConverter):
    def __init__(self, *args, index: int = 0):
        requires_backends(self, "protobuf")

        super(SpmConverter, self).__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        m = model_pb2.ModelProto()
        print(self.original_tokenizer.spm_files)
        with open(self.original_tokenizer.spm_files[index], "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m
        print(self.original_tokenizer)
        # with open(self.original_tokenizer.vocab_path, "r") as f:
        dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
        with open(dir_path / "vocab.json", "r") as f:
            import json
            self._vocab = json.load(f)

        if self.proto.trainer_spec.byte_fallback:
            if not getattr(self, "handle_byte_fallback", None):
                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

    def vocab(self, proto):
        vocab_size = max(self._vocab.values()) + 1
        vocab = [("<NIL>", -100) for _ in range(vocab_size)]
        for piece in proto.pieces:
            try:
                index = self._vocab[piece.piece]
            except Exception:
                print(f"Ignored missing piece {piece.piece}")
            vocab[index] = (piece.piece, piece.score)
        return vocab


tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")
22
candle-examples/examples/marian-mt/python/requirements.txt
Normal file
@ -0,0 +1,22 @@
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
filelock==3.18.0
fsspec==2025.3.2
huggingface-hub==0.30.1
idna==3.10
joblib==1.4.2
numpy==2.2.4
packaging==24.2
protobuf==6.30.2
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
sacremoses==0.1.1
safetensors==0.5.3
sentencepiece==0.2.0
tokenizers==0.21.1
tqdm==4.67.1
transformers==4.50.3
typing-extensions==4.13.0
urllib3==2.3.0
@ -13,6 +13,6 @@ Note that the current candle implementation suffers from some limitations as of
## Run an example

```bash
cargo run --example metavoice --release -- \\
cargo run --example metavoice --release -- \
    --prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```
@ -16,7 +16,7 @@ use candle_transformers::models::quantized_metavoice::transformer as qtransforme
use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use hf_hub::api::sync::Api;
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};

pub const ENCODEC_NTOKENS: u32 = 1024;

@ -250,7 +250,7 @@ fn main() -> Result<()> {
        let logits = logits.i(step)?.to_dtype(DType::F32)?;
        let logits = &(&logits / 1.0)?;
        let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
        let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
        let distr = rand::distr::weighted::WeightedIndex::new(prs.as_slice())?;
        let sample = distr.sample(&mut rng) as u32;
        codes_.push(sample)
    }
16
candle-examples/examples/mnist-training/README.md
Normal file
@ -0,0 +1,16 @@
# candle-mnist-training

Training a two-layer MLP on MNIST in Candle.

## Running an example

```bash
$ cargo run --example mnist-training --features candle-datasets

> train-images: [60000, 784]
> train-labels: [60000]
> test-images: [10000, 784]
> test-labels: [10000]
> 1 train loss: 2.30265 test acc: 68.08%
> 2 train loss: 1.50815 test acc: 60.77%
```
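A sketch of the two-layer MLP the README describes, using candle-nn's `linear` builder — the layer sizes are illustrative (784 inputs for flattened 28×28 MNIST images, 10 classes out) and are not read from this example's source:

```rust
use candle_core::{Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};

struct Mlp {
    l1: Linear,
    l2: Linear,
}

impl Mlp {
    fn new(vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            // 784 -> 100 -> 10; hidden width is an assumption for illustration.
            l1: linear(784, 100, vb.pp("l1"))?,
            l2: linear(100, 10, vb.pp("l2"))?,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.l2.forward(&self.l1.forward(xs)?.relu()?)
    }
}
```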
@ -7,6 +7,7 @@ extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use rand::prelude::*;
use rand::rng;

use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
@ -138,7 +139,7 @@ fn training_loop_cnn(
    let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
    for epoch in 1..args.epochs {
        let mut sum_loss = 0f32;
        batch_idxs.shuffle(&mut thread_rng());
        batch_idxs.shuffle(&mut rng());
        for batch_idx in batch_idxs.iter() {
            let train_images = train_images.narrow(0, batch_idx * BSIZE, BSIZE)?;
            let train_labels = train_labels.narrow(0, batch_idx * BSIZE, BSIZE)?;
@ -12,7 +12,7 @@ $ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jp

Now you can run Moondream from the `candle-examples` crate:
```bash
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
$ cargo run --example moondream --release -- --prompt "Describe the people behind the bikers?" --image "candle-examples/examples/yolo-v8/assets/bike.jpg"

avx: false, neon: true, simd128: false, f16c: false
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
20
candle-examples/examples/musicgen/README.md
Normal file
@ -0,0 +1,20 @@
# candle-musicgen

Candle implementation of MusicGen from [Simple and Controllable Music Generation](https://arxiv.org/pdf/2306.05284).

## Running an example

```bash
$ cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums"

> tokens: [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7, 1]
> Tensor[dims 1, 13; u32]
> [[[ 0.0902, 0.1256, -0.0585, ..., 0.1057, -0.5141, -0.4675],
>  [ 0.1972, -0.0268, -0.3368, ..., -0.0495, -0.3597, -0.3940],
>  [-0.0855, -0.0007, 0.2225, ..., -0.2804, -0.5360, -0.2436],
>  ...
>  [ 0.0515, 0.0235, -0.3855, ..., -0.4728, -0.6858, -0.2923],
>  [-0.3728, -0.1442, -0.1179, ..., -0.4388, -0.0287, -0.3242],
>  [ 0.0163, 0.0012, -0.0020, ..., 0.0142, 0.0173, -0.0103]]]
> Tensor[[1, 13, 768], f32]
```
@ -148,6 +148,8 @@ enum WhichModel {
    #[value(name = "3-medium")]
    V3Medium,
    #[value(name = "4-mini")]
    V4Mini,
    #[value(name = "2-old")]
    V2Old,
    PuffinPhiV2,
    PhiHermes,
@ -261,6 +263,7 @@ fn main() -> Result<()> {
        WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
        WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
        WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
        WhichModel::V4Mini => "microsoft/Phi-4-mini-instruct".to_string(),
        WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
            "lmz/candle-quantized-phi".to_string()
        }
@ -281,6 +284,7 @@ fn main() -> Result<()> {
        WhichModel::V2
        | WhichModel::V3
        | WhichModel::V3Medium
        | WhichModel::V4Mini
        | WhichModel::PuffinPhiV2
        | WhichModel::PhiHermes => "main".to_string(),
    }
@ -296,7 +300,8 @@ fn main() -> Result<()> {
        | WhichModel::V2
        | WhichModel::V2Old
        | WhichModel::V3
        | WhichModel::V3Medium => repo.get("tokenizer.json")?,
        | WhichModel::V3Medium
        | WhichModel::V4Mini => repo.get("tokenizer.json")?,
        WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
            repo.get("tokenizer-puffin-phi-v2.json")?
        }
@ -312,19 +317,21 @@ fn main() -> Result<()> {
        WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
        WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
        WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
        WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
        WhichModel::V3 | WhichModel::V3Medium | WhichModel::V4Mini => anyhow::bail!(
            "use the quantized or quantized-phi examples for quantized phi-v3"
        ),
    }
} else {
    match args.model {
        WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
        WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
            candle_examples::hub_load_safetensors(
                &repo,
                "model.safetensors.index.json",
            )?
        }
        WhichModel::V2
        | WhichModel::V2Old
        | WhichModel::V3
        | WhichModel::V3Medium
        | WhichModel::V4Mini => candle_examples::hub_load_safetensors(
            &repo,
            "model.safetensors.index.json",
        )?,
        WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
        WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
    }
@ -341,7 +348,7 @@ fn main() -> Result<()> {
        WhichModel::V2 | WhichModel::V2Old => Config::v2(),
        WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
        WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
        WhichModel::V3 | WhichModel::V3Medium => {
        WhichModel::V3 | WhichModel::V3Medium | WhichModel::V4Mini => {
            panic!("use the quantized or quantized-phi examples for quantized phi-v3")
        }
    };
@ -361,7 +368,10 @@ fn main() -> Result<()> {
    let dtype = match args.dtype {
        Some(dtype) => std::str::FromStr::from_str(&dtype)?,
        None => {
            if args.model == WhichModel::V3 || args.model == WhichModel::V3Medium {
            if args.model == WhichModel::V3
                || args.model == WhichModel::V3Medium
                || args.model == WhichModel::V4Mini
            {
                device.bf16_default_to_f32()
            } else {
                DType::F32
@ -377,7 +387,7 @@ fn main() -> Result<()> {
            let phi = Phi::new(&config, vb)?;
            Model::Phi(phi)
        }
        WhichModel::V3 | WhichModel::V3Medium => {
        WhichModel::V3 | WhichModel::V3Medium | WhichModel::V4Mini => {
            let config_filename = repo.get("config.json")?;
            let config = std::fs::read_to_string(config_filename)?;
            let config: Phi3Config = serde_json::from_str(&config)?;
20
candle-examples/examples/quantized-phi/README.md
Normal file
@ -0,0 +1,20 @@
# candle-quantized-phi

Candle implementation of various quantized Phi models.

## Running an example

```bash
$ cargo run --example quantized-phi --release -- --prompt "The best thing about coding in rust is "

> - it's memory safe (without you having to worry too much)
> - the borrow checker is really smart and will catch your mistakes for free, making them show up as compile errors instead of segfaulting in runtime.
>
> This alone make me prefer using rust over c++ or go, python/Cython etc.
>
> The major downside I can see now:
> - it's slower than other languages (viz: C++) and most importantly lack of libraries to leverage existing work done by community in that language. There are so many useful machine learning libraries available for c++, go, python etc but none for Rust as far as I am aware of on the first glance.
> - there aren't a lot of production ready projects which also makes it very hard to start new one (given my background)
>
> Another downside:
```
@ -27,6 +27,8 @@ enum Which {
    W2_7b,
    #[value(name = "72b")]
    W2_72b,
    #[value(name = "deepseekr1-qwen7b")]
    DeepseekR1Qwen7B,
}

#[derive(Parser, Debug)]
@ -102,6 +104,7 @@ impl Args {
            Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
            Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
            Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
            Which::DeepseekR1Qwen7B => "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        };
        let api = api.model(repo.to_string());
        api.get("tokenizer.json")?
@ -135,6 +138,11 @@ impl Args {
                "qwen2-72b-instruct-q4_0.gguf",
                "main",
            ),
            Which::DeepseekR1Qwen7B => (
                "unsloth/DeepSeek-R1-Distill-Qwen-7B-GGUF",
                "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf",
                "main",
            ),
        };
        let api = hf_hub::api::sync::Api::new()?;
        api.repo(hf_hub::Repo::with_revision(
@ -211,11 +219,15 @@ fn main() -> anyhow::Result<()> {

    let tokenizer = args.tokenizer()?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
    let prompt_str = format!(
        "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
        prompt_str
    );
    let prompt_str = args
        .prompt
        .clone()
        .unwrap_or_else(|| DEFAULT_PROMPT.to_string());

    let prompt_str = match args.which {
        Which::DeepseekR1Qwen7B => format!("<|User|>{prompt_str}<|Assistant|>"),
        _ => format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"),
    };
    print!("formatted instruct prompt: {}", &prompt_str);
    let tokens = tos
        .tokenizer()
@ -260,7 +272,13 @@ fn main() -> anyhow::Result<()> {
        print!("{t}");
        std::io::stdout().flush()?;
    }
    let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();

    let eos_token = match args.which {
        Which::DeepseekR1Qwen7B => "<|end▁of▁sentence|>",
        _ => "<|im_end|>",
    };

    let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
    let start_post_prompt = std::time::Instant::now();
    let mut sampled = 0;
    for index in 0..to_sample {
@ -1,5 +1,7 @@
# candle-quantized-t5

Candle implementation for quantizing and running T5 translation models.

## Seq2Seq example

This example uses a quantized version of the t5 model.
@ -75,6 +75,8 @@ enum Which {
    SmolLM2_360MInstruct,
    #[value(name = "SmoLM2-1.7B-Instruct")]
    SmolLM2_1BInstruct,
    #[value(name = "deepseekr1-llama8b")]
    DeepseekR1Llama8b,
}

impl Which {
@ -94,7 +96,8 @@ impl Which {
            | Self::L8b
            | Self::Phi3
            | Self::SmolLM2_1BInstruct
            | Self::SmolLM2_360MInstruct => false,
            | Self::SmolLM2_360MInstruct
            | Self::DeepseekR1Llama8b => false,
            // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
            // same way. Starling is a fine tuned version of OpenChat.
            Self::OpenChat35
@ -132,7 +135,8 @@ impl Which {
            | Self::L8b
            | Self::SmolLM2_1BInstruct
            | Self::SmolLM2_360MInstruct
            | Self::Phi3 => false,
            | Self::Phi3
            | Self::DeepseekR1Llama8b => false,
            Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
        }
    }
@ -160,11 +164,41 @@ impl Which {
            | Self::L8b
            | Self::SmolLM2_1BInstruct
            | Self::SmolLM2_360MInstruct
            | Self::Phi3 => false,
            | Self::Phi3
            | Self::DeepseekR1Llama8b => false,
            Self::OpenChat35 | Self::Starling7bAlpha => true,
        }
    }

    fn is_deepseek(&self) -> bool {
        match self {
            Self::L7b
            | Self::L13b
            | Self::L70b
            | Self::L7bChat
            | Self::L13bChat
            | Self::L70bChat
            | Self::L7bCode
            | Self::L13bCode
            | Self::L34bCode
            | Self::Leo7b
            | Self::Leo13b
            | Self::Mixtral
            | Self::MixtralInstruct
            | Self::Mistral7b
            | Self::Mistral7bInstruct
            | Self::Mistral7bInstructV02
            | Self::Zephyr7bAlpha
            | Self::Zephyr7bBeta
            | Self::L8b
            | Self::SmolLM2_1BInstruct
            | Self::SmolLM2_360MInstruct
            | Self::Phi3
            | Self::OpenChat35
            | Self::Starling7bAlpha => false,
            Self::DeepseekR1Llama8b => true,
        }
    }
    fn tokenizer_repo(&self) -> &'static str {
        match self {
            Self::L7b
@ -191,6 +225,7 @@ impl Which {
            Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
            Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
            Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
            Self::DeepseekR1Llama8b => "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        }
    }
}
@ -363,6 +398,10 @@ impl Args {
                "HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
                "smollm2-1.7b-instruct-q4_k_m.gguf",
            ),
            Which::DeepseekR1Llama8b => (
                "unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
                "DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf",
            ),
        };
        let revision = if self.which == Which::Phi3 {
            "5eef2ce24766d31909c0b269fe90c817a8f263fb"
@ -477,6 +516,7 @@ fn main() -> anyhow::Result<()> {
        | Which::L8b
        | Which::SmolLM2_1BInstruct
        | Which::SmolLM2_360MInstruct
        | Which::DeepseekR1Llama8b
        | Which::Phi3 => 1,
        Which::Mixtral
        | Which::MixtralInstruct
@ -530,6 +570,8 @@ fn main() -> anyhow::Result<()> {
        }
    } else if args.which.is_mistral() {
        format!("[INST] {prompt} [/INST]")
    } else if args.which.is_deepseek() {
        format!("<|User|>{prompt}<|Assistant|>")
    } else {
        prompt
    }
@ -597,6 +639,7 @@ fn main() -> anyhow::Result<()> {
    let eos_token = match args.which {
        Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
        Which::L8b => "<|end_of_text|>",
        Which::DeepseekR1Llama8b => "<|end▁of▁sentence|>",
        _ => match args.which.is_open_chat() {
            true => "<|end_of_turn|>",
            false => "</s>",
@@ -2,6 +2,11 @@

Reinforcement Learning examples for candle.

> [!WARNING]
> uv is not currently compatible with pyo3 as of 2025/3/28.

## System wide python

This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash
@@ -5,7 +5,7 @@ use candle_nn::{
    func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
    VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};

use super::gym_env::GymEnv;

@@ -103,8 +103,8 @@ impl ReplayBuffer {
    if self.size < batch_size {
        Ok(None)
    } else {
        let transitions: Vec<&Transition> = thread_rng()
            .sample_iter(Uniform::from(0..self.size))
        let transitions: Vec<&Transition> = rng()
            .sample_iter(Uniform::try_from(0..self.size).map_err(Error::wrap)?)
            .take(batch_size)
            .map(|i| self.buffer.get(i).unwrap())
            .collect();
@@ -498,11 +498,11 @@ pub fn run() -> Result<()> {
    OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;

let mut rng = rand::thread_rng();
let mut rng = rand::rng();

for episode in 0..MAX_EPISODES {
    // let mut state = env.reset(episode as u64)?;
    let mut state = env.reset(rng.gen::<u64>())?;
    let mut state = env.reset(rng.random::<u64>())?;

    let mut total_reward = 0.0;
    for _ in 0..EPISODE_LENGTH {
@@ -538,7 +538,7 @@ pub fn run() -> Result<()> {
    agent.train = false;
    for episode in 0..10 {
        // let mut state = env.reset(episode as u64)?;
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut state = env.reset(rng.random::<u64>())?;
        let mut total_reward = 0.0;
        for _ in 0..EPISODE_LENGTH {
            let mut action = 2.0 * agent.actions(&state)?;
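These reinforcement-learning hunks are part of the rand 0.8 to 0.9 migration that runs through the rest of this diff. A self-contained sketch of the renames, assuming rand 0.9 as pinned in the workspace:

```rust
// rand 0.9 sketch: thread_rng() -> rng(), Rng::gen() -> Rng::random(),
// rand::distributions -> rand::distr, and Uniform::from(range) becomes
// the fallible Uniform::try_from(range).
use rand::distr::Uniform;
use rand::{rng, Rng};

fn sample_indices(n: usize, k: usize) -> Vec<usize> {
    let dist = Uniform::try_from(0..n).expect("non-empty range");
    rng().sample_iter(dist).take(k).collect()
}

fn main() {
    println!("{:?}", sample_indices(10, 4));
    println!("{}", rng().random::<u64>()); // was thread_rng().gen::<u64>()
}
```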
@@ -1,9 +1,8 @@
use std::collections::VecDeque;

use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use rand::{distr::Uniform, rng, Rng};

use candle::{DType, Device, Module, Result, Tensor};
use candle::{DType, Device, Error, Module, Result, Tensor};
use candle_nn::loss::mse;
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, VarBuilder, VarMap};

@@ -65,8 +64,8 @@ pub fn run() -> Result<()> {
    // fed to the model so that it performs a backward pass.
    if memory.len() > BATCH_SIZE {
        // Sample randomly from the memory.
        let batch = thread_rng()
            .sample_iter(Uniform::from(0..memory.len()))
        let batch = rng()
            .sample_iter(Uniform::try_from(0..memory.len()).map_err(Error::wrap)?)
            .take(BATCH_SIZE)
            .map(|i| memory.get(i).unwrap().clone())
            .collect::<Vec<_>>();
@@ -4,7 +4,7 @@ use candle_nn::{
    linear, ops::log_softmax, ops::softmax, sequential::seq, Activation, AdamW, Optimizer,
    ParamsAdamW, VarBuilder, VarMap,
};
use rand::{distributions::Distribution, rngs::ThreadRng, Rng};
use rand::{distr::Distribution, rngs::ThreadRng, Rng};

fn new_model(
    input_shape: &[usize],
@@ -39,7 +39,7 @@ fn accumulate_rewards(steps: &[Step<i64>]) -> Vec<f64> {
}

fn weighted_sample(probs: Vec<f32>, rng: &mut ThreadRng) -> Result<usize> {
    let distribution = rand::distributions::WeightedIndex::new(probs).map_err(Error::wrap)?;
    let distribution = rand::distr::weighted::WeightedIndex::new(probs).map_err(Error::wrap)?;
    let mut rng = rng;
    Ok(distribution.sample(&mut rng))
}
@@ -65,10 +65,10 @@ pub fn run() -> Result<()> {

    let mut optimizer = AdamW::new(varmap.all_vars(), optimizer_params)?;

    let mut rng = rand::thread_rng();
    let mut rng = rand::rng();

    for epoch_idx in 0..100 {
        let mut state = env.reset(rng.gen::<u64>())?;
        let mut state = env.reset(rng.random::<u64>())?;
        let mut steps: Vec<Step<i64>> = vec![];

        loop {
@@ -84,7 +84,7 @@ pub fn run() -> Result<()> {
    steps.push(step.copy_with_obs(&state));

    if step.terminated || step.truncated {
        state = env.reset(rng.gen::<u64>())?;
        state = env.reset(rng.random::<u64>())?;
        if steps.len() > 5000 {
            break;
        }
    }
@@ -7,7 +7,7 @@ probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example resnet --release -- --image tiger.jpg
$ cargo run --example resnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
@@ -10,9 +10,11 @@ If you want you can use the example images from this [pull request][pr], downloa

```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
cargo run --example segformer classify candle-examples/examples/yolo-v8/assets/bike.jpg

# run the segmentation task
cargo run --example segformer segment <path-to-image>
cargo run --example segformer segment candle-examples/examples/yolo-v8/assets/bike.jpg

```

Example output for classification:
@@ -14,8 +14,8 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).

```bash
cargo run --example segment-anything --release -- \
    --image candle-examples/examples/yolo-v8/assets/bike.jpg
    --use-tiny
    --image candle-examples/examples/yolo-v8/assets/bike.jpg \
    --use-tiny \
    --point 0.6,0.6 --point 0.6,0.55
```
@@ -5,7 +5,7 @@ SigLIP is a multi-modal text-vision model that improves over CLIP by using a sigmo

### Running an example
```
$ cargo run --features cuda -r --example siglip -
$ cargo run --features cuda -r --example siglip
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
```
@@ -13,11 +13,40 @@ use candle_transformers::models::siglip;

use tokenizers::Tokenizer;

#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum Which {
    #[value(name = "v1-base-patch16-224")]
    V1BasePatch16_224,
    #[value(name = "v2-base-patch16-224")]
    V2BasePatch16_224,
    #[value(name = "v2-base-patch16-256")]
    V2BasePatch16_256,
    #[value(name = "v2-base-patch16-384")]
    V2BasePatch16_384,
    #[value(name = "v2-base-patch16-512")]
    V2BasePatch16_512,
    #[value(name = "v2-large-patch16-256")]
    V2LargePatch16_256,
    #[value(name = "v2-large-patch16-384")]
    V2LargePatch16_384,
    #[value(name = "v2-large-patch16-512")]
    V2LargePatch16_512,
}

#[derive(Parser)]
struct Args {
    #[arg(long)]
    model: Option<String>,

    #[arg(long)]
    config: Option<String>,

    #[arg(long)]
    hf_repo: Option<String>,

    #[arg(long, default_value = "v1-base-patch16-224")]
    which: Which,

    #[arg(long)]
    tokenizer: Option<String>,

@@ -66,16 +95,37 @@ fn load_images<T: AsRef<std::path::Path>>(

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let hf_repo = match args.hf_repo.as_ref() {
        Some(hf_repo) => hf_repo,
        None => match args.which {
            Which::V1BasePatch16_224 => "google/siglip-base-patch16-224",
            Which::V2BasePatch16_224 => "google/siglip2-base-patch16-224",
            Which::V2BasePatch16_256 => "google/siglip2-base-patch16-256",
            Which::V2BasePatch16_384 => "google/siglip2-base-patch16-384",
            Which::V2BasePatch16_512 => "google/siglip2-base-patch16-512",
            Which::V2LargePatch16_256 => "google/siglip2-large-patch16-256",
            Which::V2LargePatch16_384 => "google/siglip2-large-patch16-384",
            Which::V2LargePatch16_512 => "google/siglip2-large-patch16-512",
        },
    };
    let model_file = match args.model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/siglip-base-patch16-224".to_string());
            let api = api.model(hf_repo.to_string());
            api.get("model.safetensors")?
        }
        Some(model) => model.into(),
    };
    let tokenizer = get_tokenizer(args.tokenizer)?;
    let config = siglip::Config::base_patch16_224();
    let config_file = match args.config {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model(hf_repo.to_string());
            api.get("config.json")?
        }
        Some(config) => config.into(),
    };
    let tokenizer = get_tokenizer(hf_repo, args.tokenizer)?;
    let config: siglip::Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
    let device = candle_examples::device(args.cpu)?;
    let vec_imgs = match args.images {
        Some(imgs) => imgs,
@@ -114,11 +164,11 @@ pub fn main() -> anyhow::Result<()> {
    Ok(())
}

pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
pub fn get_tokenizer(hf_repo: &str, tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
    let tokenizer = match tokenizer {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/siglip-base-patch16-224".to_string());
            let api = api.model(hf_repo.to_string());
            api.get("tokenizer.json")?
        }
        Some(file) => file.into(),
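With these hunks the SigLIP example resolves the checkpoint from the new `--which` flag (or an explicit `--hf-repo` override) instead of hard-coding `google/siglip-base-patch16-224`, so a SigLIP 2 variant can be selected as, e.g., `cargo run --features cuda -r --example siglip -- --which v2-base-patch16-256`.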
@@ -6,7 +6,14 @@ This example uses the models available in the hugging face [onnx-community/siler

## Running the example

### using arecord

```bash
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```

### using SoX

```bash
$ rec -t raw -r 48000 -b 16 -c 1 -e signed-integer - trim 0 5 | sox -t raw -r 48000 -b 16 -c 1 -e signed-integer - -t raw -r 16000 -b 16 -c 1 -e signed-integer - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
```
@@ -617,7 +617,7 @@ fn run(args: Args) -> Result<()> {
    let mut scheduler = sd_config.build_scheduler(n_steps)?;
    let device = candle_examples::device(cpu)?;
    // If a seed is not given, generate a random seed and print it
    let seed = seed.unwrap_or(rand::thread_rng().gen_range(0u64..u64::MAX));
    let seed = seed.unwrap_or(rand::rng().random_range(0u64..u64::MAX));
    println!("Using seed {seed}");
    device.set_seed(seed)?;
    let use_guide_scale = guidance_scale > 1.0;
candle-examples/examples/starcoder2/README.md (new file, 15 lines)

@@ -0,0 +1,15 @@
# candle-starcoder2

Candle implementation of the StarCoder2 family of code generation models from [StarCoder 2 and The Stack v2: The Next Generation](https://arxiv.org/pdf/2402.19173).

## Running an example

```bash
$ cargo run --example starcoder2 -- --prompt "write a recursive fibonacci function in python "

> # that returns the nth number in the sequence.
>
> def fib(n):
>     if n
```
@@ -10,7 +10,7 @@ Stella_en_1.5B_v5 is used to generate text embeddings for a prompt. T
are downloaded from the hub on the first run.

```bash
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?" --which 1.5b

> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
> Tensor[[1, 1024], f32]
@@ -1,5 +1,7 @@
# candle-t5

Candle implementations of the T5 family of translation models.

## Encoder-decoder example:

```bash
@@ -7,7 +7,7 @@ The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main

You can run the example with the following command:

```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
cargo run --example vgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which vgg13
```

In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
@@ -7,8 +7,8 @@ probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example vit --release -- --image tiger.jpg
```bash
$ cargo run --example vit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg

loaded image Tensor[dims 3, 224, 224; f32]
model built
candle-examples/examples/whisper-microphone/README.md (new file, 15 lines)

@@ -0,0 +1,15 @@
# candle-whisper-microphone

Whisper implementation using microphone as input.

## Running an example

```bash
$ cargo run --example whisper-microphone --features microphone

> transcribing audio...
> 480256 160083
> language_token: None
> 0.0s -- 30.0s: Hello, hello, I don't know if this is working, but You know, how long did I make this?
> 480256 160085
```
@@ -9,7 +9,7 @@ use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod multilingual;
@@ -204,7 +204,7 @@ impl Decoder {
    let next_token = if t > 0f64 {
        let prs = softmax(&(&logits / t)?, 0)?;
        let logits_v: Vec<f32> = prs.to_vec1()?;
        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
        let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
        distr.sample(&mut self.rng) as u32
    } else {
        let logits_v: Vec<f32> = logits.to_vec1()?;
@@ -14,7 +14,9 @@ use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use rand::distr::weighted::WeightedIndex;
use rand::distr::Distribution;
use rand::SeedableRng;
use tokenizers::Tokenizer;

mod multilingual;
@@ -208,7 +210,7 @@ impl Decoder {
    let next_token = if t > 0f64 {
        let prs = softmax(&(&logits / t)?, 0)?;
        let logits_v: Vec<f32> = prs.to_vec1()?;
        let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
        let distr = WeightedIndex::new(&logits_v)?;
        distr.sample(&mut self.rng) as u32
    } else {
        let logits_v: Vec<f32> = logits.to_vec1()?;
candle-examples/examples/yi/README.md (new file, 13 lines)

@@ -0,0 +1,13 @@
# candle-yi

Candle implementations of the Yi family of bilingual (English, Chinese) LLMs.

## Running an example

```bash
$ cargo run --example yi -- --prompt "Here is a test sentence"

> python
> print("Hello World")
>
```
candle-examples/examples/yolo-v3/README.md (new file, 32 lines)

@@ -0,0 +1,32 @@
# candle-yolo-v3:

Candle implementation of Yolo-V3 for object detection.

## Running an example

```bash
$ cargo run --example yolo-v3 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg

> generated predictions Tensor[dims 10647, 85; f32]
> person: Bbox { xmin: 46.362198, ymin: 72.177, xmax: 135.92522, ymax: 339.8356, confidence: 0.99705493, data: () }
> person: Bbox { xmin: 137.25645, ymin: 67.58148, xmax: 216.90437, ymax: 333.80756, confidence: 0.9898516, data: () }
> person: Bbox { xmin: 245.7842, ymin: 82.76726, xmax: 316.79053, ymax: 337.21613, confidence: 0.9884322, data: () }
> person: Bbox { xmin: 207.52783, ymin: 61.815224, xmax: 266.77884, ymax: 307.92606, confidence: 0.9860648, data: () }
> person: Bbox { xmin: 11.457404, ymin: 60.335564, xmax: 34.39357, ymax: 187.7714, confidence: 0.9545012, data: () }
> person: Bbox { xmin: 251.88353, ymin: 11.235481, xmax: 286.56607, ymax: 92.54697, confidence: 0.8439807, data: () }
> person: Bbox { xmin: -0.44309902, ymin: 55.486923, xmax: 13.160354, ymax: 184.09705, confidence: 0.8266243, data: () }
> person: Bbox { xmin: 317.40826, ymin: 55.39501, xmax: 370.6704, ymax: 153.74887, confidence: 0.7327442, data: () }
> person: Bbox { xmin: 370.02835, ymin: 66.120224, xmax: 404.22824, ymax: 142.09691, confidence: 0.7265741, data: () }
> person: Bbox { xmin: 250.36511, ymin: 57.349842, xmax: 280.06335, ymax: 116.29384, confidence: 0.709422, data: () }
> person: Bbox { xmin: 32.573215, ymin: 66.66239, xmax: 50.49056, ymax: 173.42068, confidence: 0.6998766, data: () }
> person: Bbox { xmin: 131.72215, ymin: 63.946213, xmax: 166.66151, ymax: 241.52773, confidence: 0.64457536, data: () }
> person: Bbox { xmin: 407.42416, ymin: 49.106407, xmax: 415.24307, ymax: 84.7134, confidence: 0.5955802, data: () }
> person: Bbox { xmin: 51.650482, ymin: 64.4985, xmax: 67.40904, ymax: 106.952385, confidence: 0.5196007, data: () }
> bicycle: Bbox { xmin: 160.10031, ymin: 183.90837, xmax: 200.86832, ymax: 398.609, confidence: 0.9623588, data: () }
> bicycle: Bbox { xmin: 66.570915, ymin: 192.56966, xmax: 112.06765, ymax: 369.28497, confidence: 0.9174347, data: () }
> bicycle: Bbox { xmin: 258.2856, ymin: 197.04532, xmax: 298.43106, ymax: 364.8627, confidence: 0.6851388, data: () }
> bicycle: Bbox { xmin: 214.0034, ymin: 175.76498, xmax: 252.45158, ymax: 356.53818, confidence: 0.67071193, data: () }
> motorbike: Bbox { xmin: 318.23938, ymin: 95.22487, xmax: 369.9743, ymax: 213.46263, confidence: 0.96691036, data: () }
> motorbike: Bbox { xmin: 367.46417, ymin: 100.07982, xmax: 394.9981, ymax: 174.6545, confidence: 0.9185384, data: () }
> writing "candle-examples/examples/yolo-v8/assets/bike.pp.jpg"
```
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.8.3"
version = "0.9.0-alpha.1"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.3" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@@ -88,19 +88,26 @@ fn main() -> Result<()> {
    .arg("--use_fast_math")
    .arg("--verbose");

let mut is_target_msvc = false;
if let Ok(target) = std::env::var("TARGET") {
    if target.contains("msvc") {
        is_target_msvc = true;
        builder = builder.arg("-D_USE_MATH_DEFINES");
    }
}

if !is_target_msvc {
    builder = builder.arg("-Xcompiler").arg("-fPIC");
}

let out_file = build_dir.join("libflashattention.a");
builder.build_lib(out_file);

println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++");

if !is_target_msvc {
    println!("cargo:rustc-link-lib=dylib=stdc++");
}
Ok(())
}
@@ -88,6 +88,7 @@ impl FlashAttn {
    candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
}

let stream = dev.cuda_stream();
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
    if alibi_slopes.dtype() != DType::F32 {
        candle::bail!(
@@ -114,7 +115,9 @@ impl FlashAttn {

    let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

    *alibi_slopes.device_ptr() as *const core::ffi::c_void
    // Dropping the guard here doesn't seem very safe.
    let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
    ptr as *const core::ffi::c_void
} else {
    std::ptr::null()
};
@@ -161,17 +164,17 @@ impl FlashAttn {
}

unsafe {
    let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
    let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
    let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
    let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
    let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
    let (q_ptr, _guard) = q.device_ptr(&stream);
    let (k_ptr, _guard) = k.device_ptr(&stream);
    let (v_ptr, _guard) = v.device_ptr(&stream);
    let (dst_ptr, _guard) = dst.device_ptr(&stream);
    let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
    ffi::run_mha(
        q_ptr,
        k_ptr,
        v_ptr,
        dst_ptr,
        softmax_lse_ptr,
        q_ptr as *const core::ffi::c_void,
        k_ptr as *const core::ffi::c_void,
        v_ptr as *const core::ffi::c_void,
        dst_ptr as *const core::ffi::c_void,
        softmax_lse_ptr as *const core::ffi::c_void,
        /* alibi_slopes_ptr */ alibi_slopes_ptr,
        /* cu_seqlens_q_ptr */ std::ptr::null(),
        /* cu_seqlens_k_ptr */ std::ptr::null(),
@@ -550,6 +553,7 @@ impl FlashAttnVarLen {

    let batch_size = nseqlens_q - 1;

    let stream = dev.cuda_stream();
    let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
        if alibi_slopes.dtype() != DType::F32 {
            candle::bail!(
@@ -576,7 +580,9 @@ impl FlashAttnVarLen {

    let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);

    *alibi_slopes.device_ptr() as *const core::ffi::c_void
    // Dropping the guard here doesn't seem very safe.
    let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
    ptr as *const core::ffi::c_void
} else {
    std::ptr::null()
};
@@ -621,22 +627,22 @@ impl FlashAttnVarLen {
}

unsafe {
    let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
    let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
    let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
    let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
    let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
    let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
    let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
    let (q_ptr, _guard) = q.device_ptr(&stream);
    let (k_ptr, _guard) = k.device_ptr(&stream);
    let (v_ptr, _guard) = v.device_ptr(&stream);
    let (dst_ptr, _guard) = dst.device_ptr(&stream);
    let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
    let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
    let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
    ffi::run_mha(
        q_ptr,
        k_ptr,
        v_ptr,
        dst_ptr,
        softmax_lse_ptr,
        /* alibi_slopes_ptr */ alibi_slopes_ptr,
        /* cu_seqlens_q_ptr */ seqlens_q_ptr,
        /* cu_seqlens_k_ptr */ seqlens_k_ptr,
        q_ptr as *const core::ffi::c_void,
        k_ptr as *const core::ffi::c_void,
        v_ptr as *const core::ffi::c_void,
        dst_ptr as *const core::ffi::c_void,
        softmax_lse_ptr as *const core::ffi::c_void,
        /* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
        /* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
        /* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
        /* q_batch_stride */ 0,
        /* k_batch_stride */ 0,
        /* v_batch_stride */ 0,
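The recurring pattern in these flash-attention hunks follows from the cudarc 0.14 upgrade pinned in the workspace: `device_ptr` now takes the stream and returns a pair of the raw device address and a guard, and (as far as the API is documented) the guard must stay alive while the pointer is in use. That is why each call binds a named `_guard` that lives until the end of the `unsafe` block rather than discarding it, and why the in-code comment flags that dropping the alibi-slopes guard early "doesn't seem very safe".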
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.8.3"
version = "0.9.0-alpha.1"
edition = "2021"

description = "CUDA kernels for Candle"

@@ -7,5 +7,5 @@ fn main() {
    let builder = bindgen_cuda::Builder::default();
    println!("cargo:info={builder:?}");
    let bindings = builder.build_ptx().unwrap();
    bindings.write("src/lib.rs").unwrap();
    bindings.write("src/ptx.rs").unwrap();
}
@@ -1,11 +1,78 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
mod ptx;

#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Id {
    Affine,
    Binary,
    Cast,
    Conv,
    Fill,
    Indexing,
    Quantized,
    Reduce,
    Sort,
    Ternary,
    Unary,
}

pub const ALL_IDS: [Id; 11] = [
    Id::Affine,
    Id::Binary,
    Id::Cast,
    Id::Conv,
    Id::Fill,
    Id::Indexing,
    Id::Quantized,
    Id::Reduce,
    Id::Sort,
    Id::Ternary,
    Id::Unary,
];

pub struct Module {
    index: usize,
    ptx: &'static str,
}

impl Module {
    pub fn index(&self) -> usize {
        self.index
    }

    pub fn ptx(&self) -> &'static str {
        self.ptx
    }
}

const fn module_index(id: Id) -> usize {
    let mut i = 0;
    while i < ALL_IDS.len() {
        if ALL_IDS[i] as u32 == id as u32 {
            return i;
        }
        i += 1;
    }
    panic!("id not found")
}

macro_rules! mdl {
    ($cst:ident, $id:ident) => {
        pub const $cst: Module = Module {
            index: module_index(Id::$id),
            ptx: ptx::$cst,
        };
    };
}

mdl!(AFFINE, Affine);
mdl!(BINARY, Binary);
mdl!(CAST, Cast);
mdl!(CONV, Conv);
mdl!(FILL, Fill);
mdl!(INDEXING, Indexing);
mdl!(QUANTIZED, Quantized);
mdl!(REDUCE, Reduce);
mdl!(SORT, Sort);
mdl!(TERNARY, Ternary);
mdl!(UNARY, Unary);
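With this refactor each kernel family becomes a `Module` value that pairs its PTX text with a stable dense index. A small usage sketch (the caching scheme here is assumed for illustration, not taken from candle-core): since `index()` is a dense `0..ALL_IDS.len()` key, a per-device loader can track loaded modules in a flat table instead of hashing PTX strings.

```rust
// Hypothetical per-device cache keyed by Module::index(); only the
// candle_kernels items (ALL_IDS, REDUCE, UNARY, index(), ptx()) are real.
use candle_kernels::{ALL_IDS, REDUCE, UNARY};

fn main() {
    let mut loaded: Vec<Option<&'static str>> = vec![None; ALL_IDS.len()];
    for m in [&REDUCE, &UNARY] {
        // Record the PTX once per kernel family, slotted by its dense index.
        loaded[m.index()].get_or_insert(m.ptx());
    }
    let n = loaded.iter().filter(|s| s.is_some()).count();
    println!("{n} of {} kernel families loaded", ALL_IDS.len());
}
```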
candle-kernels/src/ptx.rs (new file, 11 lines)

@@ -0,0 +1,11 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
@@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.8.3"
version = "0.9.0-alpha.1"
edition = "2021"

description = "Metal kernels for Candle"
@@ -11,6 +11,7 @@ pub struct Cache {
    all_data: Option<Tensor>,
    dim: usize,
    current_seq_len: usize,
    grow_by: usize,
    max_seq_len: usize,
}

@@ -20,6 +21,7 @@ impl Cache {
    all_data: None,
    dim,
    current_seq_len: 0,
    grow_by: max_seq_len,
    max_seq_len,
    }
}
@@ -65,11 +67,11 @@ impl Cache {
    };
    let ad = self.all_data.as_mut().unwrap();
    if self.current_seq_len + seq_len > self.max_seq_len {
        candle::bail!(
            "kv-cache: above max-seq-len {}+{seq_len}>{}",
            self.current_seq_len,
            self.max_seq_len
        )
        let mut shape = src.dims().to_vec();
        shape[self.dim] = self.grow_by;
        let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
        *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
        self.max_seq_len += self.grow_by;
    }
    ad.slice_set(src, self.dim, self.current_seq_len)?;
    self.current_seq_len += seq_len;
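The replaced bail is the whole behavioral change: a full KV cache now grows by `grow_by` steps along the cache dimension instead of erroring out. A standalone sketch of that growth step using candle tensors directly:

```rust
// Standalone sketch of the growth step introduced above: extend the
// backing tensor with a zero block of `grow_by` entries along `dim`.
use candle::{DType, Device, Result, Tensor};

fn grow(all_data: &Tensor, dim: usize, grow_by: usize) -> Result<Tensor> {
    let mut shape = all_data.dims().to_vec();
    shape[dim] = grow_by;
    let next = Tensor::zeros(shape, all_data.dtype(), all_data.device())?;
    Tensor::cat(&[all_data, &next], dim)
}

fn main() -> Result<()> {
    let cache = Tensor::zeros((1, 4, 16), DType::F32, &Device::Cpu)?;
    let cache = grow(&cache, 1, 4)?; // max_seq_len grows from 4 to 8
    assert_eq!(cache.dims(), &[1, 8, 16]);
    Ok(())
}
```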
@@ -7,7 +7,7 @@ use candle::{Result, Tensor};
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain log probabilities.
///   of categories. This is expected to contain log probabilities.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to raw logits.
///   of categories. This is expected to raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.
@@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to raw logits.
///   of categories. This is expected to raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
/// of categories.
///   of categories.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
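Since this hunk only re-indents the doc comments, the documented contract is unchanged. A small usage sketch matching that contract for `nll` (log probabilities in, batch-averaged scalar out):

```rust
// Usage sketch for candle_nn::loss::nll per the docs above: `inp` holds
// (N, C) log probabilities, `target` holds N ground-truth u32 class ids.
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let logits = Tensor::new(&[[1f32, 2., 0.5], [0.1, 0.2, 3.0]], &dev)?;
    // nll expects log probabilities, so apply log_softmax over classes first.
    let log_probs = candle_nn::ops::log_softmax(&logits, 1)?;
    let target = Tensor::new(&[1u32, 2], &dev)?;
    let loss = candle_nn::loss::nll(&log_probs, &target)?;
    println!("nll: {}", loss.to_scalar::<f32>()?);
    Ok(())
}
```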
@@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::backend::BackendStorage;
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
    };
    use candle::cuda_backend::SlicePtrOrNull;
    use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
@@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
    let cfg = LaunchConfig::for_num_elems(el_count as u32);
    let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
    let src = &src.slice(layout.start_offset()..);
    let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
    let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
    // SAFETY: Set later by running the kernel.
    let out = unsafe { dev.alloc::<T>(el_count) }.w()?;

    let params = (el_count, dims.len(), &ds, src, &out);
    let mut builder = func.builder();
    candle::builder_arg!(builder, el_count, dims.len());
    ds.builder_arg(&mut builder);
    builder.arg(src);
    builder.arg(&out);
    // SAFETY: ffi.
    unsafe { func.launch(cfg, params) }.w()?;
    unsafe { builder.launch(cfg) }.w()?;
    Ok(out)
}
}
@@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
    layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
    };
    use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
    use candle::{CudaDevice, WithDType};
@@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
    block_dim: (1, 32, 1),
    shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, n_cols as i32);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
candle::builder_arg!(builder, n_cols as i32);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
@@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
    l2: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
    };
    use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
    use candle::{CudaDevice, WithDType};
@@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
    block_dim: (block_size, 1, 1),
    shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
    &src,
    &dst,
    &alpha,
    n_cols as i32,
    block_size as i32,
    self.eps,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
@@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
    l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
    };
    use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
    use candle::{CudaDevice, WithDType};
@@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
    block_dim: (block_size, 1, 1),
    shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
let func =
    dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (
    &src,
    &dst,
    &alpha,
    &beta,
    n_cols as i32,
    block_size as i32,
    self.eps,
);
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
builder.arg(&alpha);
builder.arg(&beta);
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
unsafe { builder.launch(cfg) }.w()?;
Ok(dst)
}
}
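All four custom ops above (and the rotary-embedding ops that follow) migrate the same way for cudarc 0.14: the `LaunchAsync` trait and the params-tuple `func.launch(cfg, params)` call give way to a builder, with `PushKernelArg` in scope, tensor arguments pushed one by one via `builder.arg(..)`, scalar arguments packed through the `candle::builder_arg!` macro, and the kernel fired with `unsafe { builder.launch(cfg) }`. Kernel modules are now passed by reference (`&kernels::REDUCE`) since they are the `Module` structs introduced earlier rather than bare PTX strings.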
@@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
    l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
    };
    use candle::cuda_backend::{kernel_name, kernels, WrapErr};
    use candle::{CudaDevice, WithDType};
@@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
    let (b, h, t, d) = l_src.shape().dims4()?;
    let el = b * h * t * d;
    let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
    let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
    let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
    // SAFETY: Set later by running the kernel.
    let dst = unsafe { dev.alloc::<T>(el) }.w()?;
    let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
    let mut builder = func.builder();
    builder.arg(&src);
    builder.arg(&cos);
    builder.arg(&sin);
    builder.arg(&dst);
    candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
    // SAFETY: ffi.
    unsafe { func.launch(cfg, params) }.w()?;
    unsafe { builder.launch(cfg) }.w()?;
    Ok(dst)
}

@@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
    l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
    };
    use candle::cuda_backend::{kernel_name, kernels, WrapErr};
    use candle::{CudaDevice, WithDType};
@@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
    let (b, h, t, d) = l_src.shape().dims4()?;
    let el = b * h * t * d;
    let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
    let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
    let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
    // SAFETY: Set later by running the kernel.
    let dst = unsafe { dev.alloc::<T>(el) }.w()?;
    let params = (
        &src,
        &cos,
        &sin,
        &dst,
        (b * h) as u32,
        (t * d) as u32,
        d as u32,
    );
    let mut builder = func.builder();
    builder.arg(&src);
    builder.arg(&cos);
    builder.arg(&sin);
    builder.arg(&dst);
    candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
    // SAFETY: ffi.
    unsafe { func.launch(cfg, params) }.w()?;
    unsafe { builder.launch(cfg) }.w()?;
    Ok(dst)
}

@@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
    l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
    use candle::cuda_backend::cudarc::driver::{
        CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
        CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
    };
    use candle::cuda_backend::{kernel_name, kernels, WrapErr};
    use candle::{CudaDevice, WithDType};
@@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
    let (b, t, h, d) = l_src.shape().dims4()?;
    let el = b * h * t * d;
    let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
    let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
    let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
    // SAFETY: Set later by running the kernel.
    let dst = unsafe { dev.alloc::<T>(el) }.w()?;
    let params = (
        &src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
    );
    let mut builder = func.builder();
    builder.arg(&src);
    builder.arg(&cos);
    builder.arg(&sin);
    builder.arg(&dst);
    candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
    // SAFETY: ffi.
    unsafe { func.launch(cfg, params) }.w()?;
    unsafe { builder.launch(cfg) }.w()?;
    Ok(dst)
}
@@ -83,7 +83,7 @@ fn rms_norml(device: &Device) -> Result<()> {
    let (b_size, seq_len, head_dim) = (24, 70, 64);
    let el_count = b_size * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
    let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
    let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
    let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
@@ -130,7 +130,7 @@ fn layer_norml(device: &Device) -> Result<()> {
    let (b_size, seq_len, head_dim) = (24, 70, 64);
    let el_count = b_size * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
    let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
    let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
    let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
@@ -161,12 +161,12 @@ fn ropei(device: &Device) -> Result<()> {
    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .map(|_| rng.random::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .map(|_| rng.random::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@@ -188,12 +188,12 @@ fn rope(device: &Device) -> Result<()> {
    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .map(|_| rng.random::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .map(|_| rng.random::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@@ -215,12 +215,12 @@ fn rope_thd(device: &Device) -> Result<()> {
    let (b_size, num_head, seq_len, head_dim) = (2, 5, 10, 16);
    let el_count = b_size * num_head * seq_len * head_dim;
    let mut rng = StdRng::seed_from_u64(299792458);
    let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
    let src: Vec<f32> = (0..el_count).map(|_| rng.random::<f32>()).collect();
    let cos: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .map(|_| rng.random::<f32>())
        .collect();
    let sin: Vec<f32> = (0..seq_len * head_dim / 2)
        .map(|_| rng.gen::<f32>())
        .map(|_| rng.random::<f32>())
        .collect();
    let src = Tensor::from_vec(src, (b_size, num_head, seq_len, head_dim), device)?;
    let cos = Tensor::from_vec(cos, (seq_len, head_dim / 2), device)?;
@@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.8.3"
version = "0.9.0-alpha.1"
edition = "2021"

description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.8.3" }
candle-nn = { path = "../candle-nn", version = "0.8.3" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
prost = "0.12.1"

[build-dependencies]
@@ -1,4 +1,5 @@
#![allow(clippy::redundant_closure_call)]
#![allow(clippy::useless_conversion)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
@@ -4,7 +4,7 @@
//! with support for temperature-based sampling, top-k filtering, nucleus sampling (top-p),
//! and combinations thereof.
use candle::{Context, DType, Error, Result, Tensor};
use rand::{distributions::Distribution, SeedableRng};
use rand::{distr::Distribution, SeedableRng};

#[derive(Clone, PartialEq, Debug)]
pub enum Sampling {
@@ -50,7 +50,7 @@ impl LogitsProcessor {
}

fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
    let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
    let next_token = distr.sample(&mut self.rng) as u32;
    Ok(next_token)
}
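The same weighted-sampling rename shows up here as in the whisper examples. A self-contained sketch of the new path under rand 0.9:

```rust
// rand 0.9 weighted sampling: WeightedIndex moved from
// rand::distributions to rand::distr::weighted.
use rand::distr::weighted::WeightedIndex;
use rand::distr::Distribution;
use rand::{rngs::StdRng, SeedableRng};

fn main() {
    let prs = [0.1f32, 0.2, 0.7];
    let distr = WeightedIndex::new(&prs).unwrap();
    let mut rng = StdRng::seed_from_u64(42);
    let next_token = distr.sample(&mut rng) as u32;
    println!("sampled token {next_token}");
}
```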
501
candle-transformers/src/models/csm.rs
Normal file
501
candle-transformers/src/models/csm.rs
Normal file
@ -0,0 +1,501 @@
|
||||
//! Implementation of the Conversational Speech Model (CSM) from Sesame
|
||||
//!
|
||||
//! See: [CSM](Conversational Speech Model)
|
||||
//!
|
||||
/// CSM (Conversational Speech Model) is a speech generation model from Sesame that generates RVQ
|
||||
/// audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
|
||||
/// smaller audio decoder that produces Mimi audio codes.
|
||||
///
|
||||
use crate::generation::LogitsProcessor;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Flavor {
|
||||
#[serde(rename = "llama-1B")]
|
||||
Llama1B,
|
||||
#[serde(rename = "llama-100M")]
|
||||
Llama100M,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub audio_num_codebooks: usize,
|
||||
pub audio_vocab_size: usize,
|
||||
pub backbone_flavor: Flavor,
|
||||
pub decoder_flavor: Flavor,
|
||||
pub text_vocab_size: usize,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LlamaConfig {
|
||||
vocab_size: usize,
|
||||
num_layers: usize,
|
||||
num_heads: usize,
|
||||
num_kv_heads: usize,
|
||||
embed_dim: usize,
|
||||
max_seq_len: usize,
|
||||
intermediate_dim: usize,
|
||||
norm_eps: f64,
|
||||
rope_base: f32,
|
||||
scale_factor: usize,
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
pub fn from_flavor(flavor: Flavor) -> Self {
|
||||
match flavor {
|
||||
Flavor::Llama1B => Self {
|
||||
vocab_size: 128256,
|
||||
num_layers: 16,
|
||||
num_heads: 32,
|
||||
num_kv_heads: 8,
|
||||
embed_dim: 2048,
|
||||
max_seq_len: 2048,
|
||||
intermediate_dim: 8192,
|
||||
norm_eps: 1e-5,
|
||||
rope_base: 500_000.,
|
||||
scale_factor: 32,
|
||||
},
|
||||
Flavor::Llama100M => Self {
|
||||
vocab_size: 128256,
|
||||
num_layers: 4,
|
||||
num_heads: 8,
|
||||
num_kv_heads: 2,
|
||||
embed_dim: 1024,
|
||||
max_seq_len: 2048,
|
||||
intermediate_dim: 8192,
|
||||
norm_eps: 1e-5,
|
||||
rope_base: 500_000.,
|
||||
scale_factor: 32,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
|
||||
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||
(0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
|
||||
let low_freq_factor = 1.0;
|
||||
let high_freq_factor = 4.0;
|
||||
let original_max_position_embeddings = 8192;
|
||||
let scale_factor = cfg.scale_factor as f32;
|
||||
let theta = {
|
||||
let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
|
||||
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
|
||||
|
||||
calculate_default_inv_freq(cfg)
|
||||
.into_iter()
|
||||
.map(|freq| {
|
||||
let wavelen = 2. * std::f32::consts::PI / freq;
|
||||
if wavelen < high_freq_wavelen {
|
||||
freq
|
||||
} else if wavelen > low_freq_wavelen {
|
||||
freq / scale_factor
|
||||
} else {
|
||||
let smooth = (original_max_position_embeddings as f32 / wavelen
|
||||
- low_freq_factor)
|
||||
/ (high_freq_factor - low_freq_factor);
|
||||
(1. - smooth) * freq / scale_factor + smooth * freq
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
let theta = Tensor::new(theta, dev)?;
|
||||
let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((cfg.max_seq_len, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
// This is different from the paper, see:
|
||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||
Ok(Self { cos, sin })
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
q: &Tensor,
|
||||
k: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
|
||||
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
|
||||
Ok((q_embed, k_embed))
|
||||
}
|
||||
}
|
||||
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
|
||||
let weight = vb.get((hidden_size,), "scale")?;
|
||||
Ok(RmsNorm::new(weight, eps))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Attention {
|
||||
q_proj: Linear,
|
||||
k_proj: Linear,
|
||||
v_proj: Linear,
|
||||
o_proj: Linear,
|
||||
rotary_emb: Arc<RotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
num_kv_heads: usize,
|
||||
num_kv_groups: usize,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||
let kv_dim = cfg.num_kv_heads * head_dim;
|
||||
|
||||
let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
|
||||
let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
|
||||
let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
|
||||
let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
|
||||
Ok(Self {
|
||||
q_proj,
|
||||
k_proj,
|
||||
v_proj,
|
||||
o_proj,
|
||||
rotary_emb,
|
||||
kv_cache: None,
|
||||
num_heads: cfg.num_heads,
|
||||
num_kv_heads: cfg.num_kv_heads,
|
||||
num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
|
||||
head_dim,
|
||||
})
|
||||
}
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
        let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
            .apply(&self.o_proj)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}
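
// SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)).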
#[derive(Debug, Clone)]
struct Mlp {
    w1: Linear,
    w2: Linear,
    w3: Linear,
}

impl Mlp {
    fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
        let w1 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w1"))?;
        let w2 = linear_b(cfg.intermediate_dim, cfg.embed_dim, false, vb.pp("w2"))?;
        let w3 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w3"))?;
        Ok(Self { w1, w2, w3 })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.w1)?.silu()?;
        let rhs = xs.apply(&self.w3)?;
        (lhs * rhs)?.apply(&self.w2)
    }
}

#[derive(Debug, Clone)]
struct Layer {
    mlp_norm: RmsNorm,
    sa_norm: RmsNorm,
    attn: Attention,
    mlp: Mlp,
}

impl Layer {
    fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
        let mlp_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("mlp_norm"))?;
        let sa_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("sa_norm"))?;
        let attn = Attention::new(cfg, rotary_emb, vb.pp("attn"))?;
        let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
        Ok(Self {
            mlp_norm,
            sa_norm,
            attn,
            mlp,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.sa_norm.forward(xs)?;
        let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct LlamaModel {
    layers: Vec<Layer>,
    norm: RmsNorm,
    device: Device,
    dtype: DType,
}

impl LlamaModel {
    pub fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
        let mut layers = Vec::with_capacity(cfg.num_layers);
        let vb_l = vb.pp("layers");
        for layer_idx in 0..cfg.num_layers {
            let layer = Layer::new(cfg, rotary_emb.clone(), vb_l.pp(layer_idx))?;
            layers.push(layer);
        }
        let norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("norm"))?;
        Ok(Self {
            layers,
            norm,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }

    fn prepare_decoder_attention_mask(
        &self,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }
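
    // Runs the full decoder stack; only the hidden state of the last position
    // is normalized and returned, which is all the sampling step downstream
    // needs.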
    pub fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (_b_size, seq_len, _embed_dim) = xs.dims3()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = xs.clone();
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
        }
        let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
        Ok(ys)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    backbone: LlamaModel,
    decoder: LlamaModel,
    codebook0_head: Linear,
    audio_embeddings: Embedding,
    text_embeddings: Embedding,
    projection: Linear,
    audio_head: Tensor,
    config: Config,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let backbone_cfg = LlamaConfig::from_flavor(cfg.backbone_flavor);
        let backbone = LlamaModel::new(&backbone_cfg, vb.pp("backbone"))?;
        let decoder_cfg = LlamaConfig::from_flavor(cfg.decoder_flavor);
        let decoder = LlamaModel::new(&decoder_cfg, vb.pp("decoder"))?;
        let backbone_dim = backbone_cfg.embed_dim;
        let decoder_dim = decoder_cfg.embed_dim;
        let audio_embeddings = embedding(
            cfg.audio_vocab_size * cfg.audio_num_codebooks,
            backbone_dim,
            vb.pp("audio_embeddings"),
        )?;
        let text_embeddings =
            embedding(cfg.text_vocab_size, backbone_dim, vb.pp("text_embeddings"))?;
        let projection = linear_b(backbone_dim, decoder_dim, false, vb.pp("projection"))?;
        let codebook0_head = linear_b(
            backbone_dim,
            cfg.audio_vocab_size,
            false,
            vb.pp("codebook0_head"),
        )?;
        let audio_head = vb.get(
            (
                cfg.audio_num_codebooks - 1,
                decoder_dim,
                cfg.audio_vocab_size,
            ),
            "audio_head",
        )?;
        Ok(Self {
            backbone,
            decoder,
            codebook0_head,
            audio_embeddings,
            text_embeddings,
            projection,
            audio_head,
            config: cfg.clone(),
        })
    }

    pub fn clear_kv_cache(&mut self) {
        self.backbone.clear_kv_cache();
        self.decoder.clear_kv_cache();
    }
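
    // One audio frame is produced in two stages: the backbone consumes the
    // summed (audio + text) token embeddings and predicts codebook 0, then the
    // smaller decoder predicts each remaining codebook in turn, conditioned on
    // the embedding of the codebook sampled just before it.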
    pub fn generate_frame(
        &mut self,
        tokens: &Tensor,
        tokens_mask: &Tensor,
        input_pos: usize,
        lp: &mut LogitsProcessor,
    ) -> Result<Vec<u32>> {
        let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
        let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
        let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
        let text_embeds = self.text_embeddings.forward(&text_tokens)?;
        let arange = (Tensor::arange(
            0u32,
            self.config.audio_num_codebooks as u32,
            &self.decoder.device,
        )? * self.config.audio_vocab_size as f64)?;
        let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
        let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
            b_sz,
            seq_len,
            self.config.audio_num_codebooks,
            (),
        ))?;
        let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
        let embeds = embeds.broadcast_mul(
            &tokens_mask
                .to_dtype(self.backbone.dtype)?
                .unsqueeze(D::Minus1)?,
        )?;
        let embeds = embeds.sum(2)?;
        let h = self.backbone.forward(&embeds, input_pos)?;
        let c0_logits = h.apply(&self.codebook0_head)?;
        let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
        let mut all_samples = vec![c0_sample];
        let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
        let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
        let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;

        self.decoder.clear_kv_cache();
        let mut decoder_pos = 0;
        for i in 1..self.config.audio_num_codebooks {
            let proj_h = curr_h.apply(&self.projection)?;
            let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
            decoder_pos += curr_h.dim(1)?;
            let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
            let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
            all_samples.push(ci_sample);
            let ci_sample = Tensor::from_slice(
                &[ci_sample + (i * self.config.audio_vocab_size) as u32],
                (1, 1),
                &self.decoder.device,
            )?;
            let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
            curr_h = ci_embed;
        }
        Ok(all_samples)
    }
}
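
// A hypothetical caller-side sketch (setup omitted; `model`, `tokens`,
// `tokens_mask` and `lp` are assumptions, not part of this file):
//     let codes = model.generate_frame(&tokens, &tokens_mask, input_pos, &mut lp)?;
// The returned vector holds one sample per codebook; these would be fed back as
// the audio tokens of the next frame, with `input_pos` advanced accordingly.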

1051
candle-transformers/src/models/deepseek2.rs
Normal file
File diff suppressed because it is too large.

483
candle-transformers/src/models/gemma3.rs
Normal file
@ -0,0 +1,483 @@
//! Gemma 3 LLM architecture (Google) inference implementation.
//!
//! See ["Introducing Gemma 3: The most capable model you can run on a single GPU or TPU"](https://blog.google/technology/developers/gemma-3/)
//!
//! Based on implementations from HuggingFace transformers.

use std::sync::Arc;

use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{linear_b as linear, Activation, Linear, VarBuilder};

#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub attention_bias: bool,
    pub head_dim: usize,
    pub hidden_activation: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,
    pub final_logit_softcapping: Option<f64>,
    pub attn_logit_softcapping: Option<f64>,
    pub query_pre_attn_scalar: usize,
    pub sliding_window: usize,
    pub sliding_window_pattern: usize,
    pub max_position_embeddings: usize,
}
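
// Gemma-style RMSNorm: the computation runs in f32 for numerical stability, and
// the learned scale is applied as (1 + weight).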
#[derive(Debug, Clone)]
struct RmsNorm {
    weight: Tensor,
    eps: f64,
}

impl RmsNorm {
    fn new(dim: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let weight = vb.get(dim, "weight")?;
        Ok(Self { weight, eps })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        x_normed
            .to_dtype(x_dtype)?
            .broadcast_mul(&(&self.weight + 1.0)?)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.head_dim;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: candle_nn::Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?;
        let up_proj = linear(hidden_sz, intermediate_sz, false, vb.pp("up_proj"))?;
        let down_proj = linear(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_activation,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}
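
// Gemma 3 interleaves sliding-window and global-attention layers: sliding
// layers keep a rotating KV cache bounded by the window size, global layers a
// regular growing one.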
#[derive(Debug, Clone)]
enum KvCache {
    Normal(candle_nn::kv_cache::KvCache),
    Rotating(candle_nn::kv_cache::RotatingKvCache),
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    attn_logit_softcapping: Option<f64>,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: KvCache,
    use_flash_attn: bool,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        use_flash_attn: bool,
        is_sliding: bool,
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
        let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
        let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
        let kv_cache = if is_sliding {
            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
                2,
                cfg.sliding_window,
            ))
        } else {
            KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            q_norm,
            k_norm,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            attn_logit_softcapping: cfg.attn_logit_softcapping,
            rotary_emb,
            kv_cache,
            use_flash_attn,
        })
    }
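
    // Queries and keys are RMS-normalized per head before RoPE is applied, and
    // the attention logits can optionally be tanh-softcapped before the mask
    // and softmax.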
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let query_states = self.q_norm.forward(&query_states)?;
        let key_states = self.k_norm.forward(&key_states)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &mut self.kv_cache {
            KvCache::Normal(cache) => cache.append(&key_states, &value_states)?,
            KvCache::Rotating(cache) => cache.append(&key_states, &value_states)?,
        };

        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
        let value_states =
            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

        let attn_output = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = query_states.transpose(1, 2)?;
            let k = key_states.transpose(1, 2)?;
            let v = value_states.transpose(1, 2)?;
            let scale = 1f32 / (self.head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
        } else {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match self.attn_logit_softcapping {
                None => attn_weights,
                Some(sc) => ((attn_weights / sc)?.tanh()? * sc)?,
            };

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, ()))?
            .apply(&self.o_proj)
    }

    fn clear_kv_cache(&mut self) {
        match &mut self.kv_cache {
            KvCache::Normal(c) => c.reset(),
            KvCache::Rotating(c) => c.reset(),
        }
    }
}

#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        use_flash_attn: bool,
        is_sliding: bool,
        cfg: &Config,
        vb: VarBuilder,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            use_flash_attn,
            is_sliding,
            cfg,
            vb.pp("self_attn"),
        )?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm =
            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let pre_feedforward_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("pre_feedforward_layernorm"),
        )?;
        let post_feedforward_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_feedforward_layernorm"),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = xs.apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.pre_feedforward_layernorm)?;
        let xs = xs.apply(&self.mlp)?;
        let xs = xs.apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    final_logit_softcapping: Option<f64>,
    device: Device,
    dtype: DType,
    hidden_size: usize,
    sliding_window: usize,
}

impl Model {
    pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                use_flash_attn,
                is_sliding,
                cfg,
                vb_l.pp(layer_idx),
            )?;
            layers.push(layer);
        }
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            final_logit_softcapping: cfg.final_logit_softcapping,
            device: vb.device().clone(),
            dtype: vb.dtype(),
            hidden_size: cfg.hidden_size,
            sliding_window: cfg.sliding_window,
        })
    }
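
    // Builds the additive attention mask; with a sliding window, position i may
    // attend to position j only when j <= i and i - j <= window, everything else
    // is set to -inf.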
    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mask: Vec<_> = match Some(self.sliding_window) {
            None => (0..tgt_len)
                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
                .collect(),
            Some(sliding_window) => (0..tgt_len)
                .flat_map(|i| {
                    (0..tgt_len).map(move |j| {
                        if i < j || j + sliding_window < i {
                            f32::NEG_INFINITY
                        } else {
                            0.
                        }
                    })
                })
                .collect(),
        };
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
        }
        let logits = xs
            .narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)?;
        let logits = match self.final_logit_softcapping {
            None => logits,
            Some(sc) => ((logits / sc)?.tanh()? * sc)?,
        };

        Ok(logits)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}
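
// A minimal, hypothetical decoding loop (tokenizer and sampling setup omitted;
// `cfg`, `vb` and `token_ids` are assumptions, not part of this file):
//     let mut model = Model::new(false, &cfg, vb)?;
//     let logits = model.forward(&token_ids, 0)?;        // prefill
//     let logits = model.forward(&next_token, offset)?;  // then one token at a time
// `seqlen_offset` must equal the number of tokens already stored in the KV caches.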

@ -81,6 +81,126 @@ impl Config {
            vocab_size: 59514,
        }
    }

    pub fn opus_mt_en_zh() -> Self {
        Self {
            activation_function: candle_nn::Activation::Swish,
            d_model: 512,
            decoder_attention_heads: 8,
            decoder_ffn_dim: 2048,
            decoder_layers: 6,
            decoder_start_token_id: 65000,
            decoder_vocab_size: Some(65001),
            encoder_attention_heads: 8,
            encoder_ffn_dim: 2048,
            encoder_layers: 6,
            eos_token_id: 0,
            forced_eos_token_id: 0,
            is_encoder_decoder: true,
            max_position_embeddings: 512,
            pad_token_id: 65000,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 65001,
        }
    }

    pub fn opus_mt_en_hi() -> Self {
        Self {
            activation_function: candle_nn::Activation::Swish,
            d_model: 512,
            decoder_attention_heads: 8,
            decoder_ffn_dim: 2048,
            decoder_layers: 6,
            decoder_start_token_id: 61949,
            decoder_vocab_size: Some(61950),
            encoder_attention_heads: 8,
            encoder_ffn_dim: 2048,
            encoder_layers: 6,
            eos_token_id: 0,
            forced_eos_token_id: 0,
            is_encoder_decoder: true,
            max_position_embeddings: 512,
            pad_token_id: 61949,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 61950,
        }
    }

    pub fn opus_mt_en_es() -> Self {
        Self {
            activation_function: candle_nn::Activation::Swish,
            d_model: 512,
            decoder_attention_heads: 8,
            decoder_ffn_dim: 2048,
            decoder_layers: 6,
            decoder_start_token_id: 65000,
            decoder_vocab_size: Some(65001),
            encoder_attention_heads: 8,
            encoder_ffn_dim: 2048,
            encoder_layers: 6,
            eos_token_id: 0,
            forced_eos_token_id: 0,
            is_encoder_decoder: true,
            max_position_embeddings: 512,
            pad_token_id: 65000,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 65001,
        }
    }

    pub fn opus_mt_en_fr() -> Self {
        Self {
            activation_function: candle_nn::Activation::Swish,
            d_model: 512,
            decoder_attention_heads: 8,
            decoder_ffn_dim: 2048,
            decoder_layers: 6,
            decoder_start_token_id: 59513,
            decoder_vocab_size: Some(59514),
            encoder_attention_heads: 8,
            encoder_ffn_dim: 2048,
            encoder_layers: 6,
            eos_token_id: 0,
            forced_eos_token_id: 0,
            is_encoder_decoder: true,
            max_position_embeddings: 512,
            pad_token_id: 59513,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 59514,
        }
    }

    pub fn opus_mt_en_ru() -> Self {
        Self {
            activation_function: candle_nn::Activation::Swish,
            d_model: 512,
            decoder_attention_heads: 8,
            decoder_ffn_dim: 2048,
            decoder_layers: 6,
            decoder_start_token_id: 62517,
            decoder_vocab_size: Some(62518),
            encoder_attention_heads: 8,
            encoder_ffn_dim: 2048,
            encoder_layers: 6,
            eos_token_id: 0,
            forced_eos_token_id: 0,
            is_encoder_decoder: true,
            max_position_embeddings: 512,
            pad_token_id: 62517,
            scale_embedding: true,
            share_encoder_decoder_embeddings: true,
            use_cache: true,
            vocab_size: 62518,
        }
    }
}

#[derive(Debug, Clone)]

@ -27,8 +27,10 @@ pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod csm;
pub mod dac;
pub mod debertav2;
pub mod deepseek2;
pub mod depth_anything_v2;
pub mod dinov2;
pub mod dinov2reg4;
@ -42,6 +44,7 @@ pub mod fastvit;
pub mod flux;
pub mod gemma;
pub mod gemma2;
pub mod gemma3;
pub mod glm4;
pub mod granite;
pub mod helium;

@ -6,14 +6,15 @@
//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code
//!

use candle::{DType, Device, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{
    embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear,
    Module, VarBuilder,
    embedding, layer_norm_no_bias, linear, linear_no_bias, ops::softmax, Embedding, LayerNorm,
    Linear, Module, VarBuilder,
};
use serde::Deserialize;

use core::f32;
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, Deserialize)]
@ -30,6 +31,24 @@ pub struct Config {
    pub global_rope_theta: f64,
    pub local_attention: usize,
    pub local_rope_theta: f64,
    #[serde(default)]
    #[serde(flatten)]
    pub classifier_config: Option<ClassifierConfig>,
}
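
// Pooling strategy applied before the classification head: CLS keeps the first
// token's hidden state, MEAN averages all positions weighted by the attention
// mask.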
#[derive(Debug, Clone, Deserialize, PartialEq, Copy, Default)]
#[serde(rename_all = "lowercase")]
pub enum ClassifierPooling {
    #[default]
    CLS,
    MEAN,
}

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ClassifierConfig {
    pub id2label: HashMap<String, String>,
    pub label2id: HashMap<String, String>,
    pub classifier_pooling: ClassifierPooling,
}

#[derive(Debug, Clone)]
@ -310,12 +329,11 @@ pub struct ModernBert {
    norm: LayerNorm,
    layers: Vec<ModernBertLayer>,
    final_norm: LayerNorm,
    head: ModernBertHead,
    local_attention_size: usize,
}

impl ModernBert {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
@ -359,19 +377,17 @@ impl ModernBert {
            config.layer_norm_eps,
            vb.pp("model.final_norm"),
        )?;
        let head = ModernBertHead::load(vb.pp("head"), config)?;

        Ok(Self {
            word_embeddings,
            norm,
            layers,
            final_norm,
            head,
            local_attention_size: config.local_attention,
        })
    }

    fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let seq_len = xs.shape().dims()[1];
        let global_attention_mask =
            prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?;
@ -381,7 +397,7 @@ impl ModernBert {
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?;
        }
        let xs = xs.apply(&self.final_norm)?.apply(&self.head)?;
        let xs = xs.apply(&self.final_norm)?;
        Ok(xs)
    }
}

@ -391,17 +407,98 @@ impl ModernBert {
pub struct ModernBertForMaskedLM {
    model: ModernBert,
    decoder: ModernBertDecoder,
    head: ModernBertHead,
}

impl ModernBertForMaskedLM {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let model = ModernBert::load(vb.clone(), config)?;
        let decoder = ModernBertDecoder::load(vb.clone(), config)?;
        Ok(Self { model, decoder })
        let head = ModernBertHead::load(vb.pp("head"), config)?;
        Ok(Self {
            model,
            decoder,
            head,
        })
    }

    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?;
        let xs = self
            .model
            .forward(xs, mask)?
            .apply(&self.head)?
            .apply(&self.decoder)?;
        Ok(xs)
    }
}

#[derive(Clone)]
pub struct ModernBertClassifier {
    classifier: Linear,
}

impl ModernBertClassifier {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let classifier = linear(
            config.hidden_size,
            config
                .classifier_config
                .as_ref()
                .map(|cc| cc.id2label.len())
                .unwrap_or_default(),
            vb.pp("classifier"),
        )?;
        Ok(Self { classifier })
    }
}

impl Module for ModernBertClassifier {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.classifier)?;
        softmax(&xs, D::Minus1)
    }
}

#[derive(Clone)]
pub struct ModernBertForSequenceClassification {
    model: ModernBert,
    head: ModernBertHead,
    classifier: ModernBertClassifier,
    classifier_pooling: ClassifierPooling,
}

impl ModernBertForSequenceClassification {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let model = ModernBert::load(vb.clone(), config)?;
        let classifier = ModernBertClassifier::load(vb.clone(), config)?;
        let head = ModernBertHead::load(vb.pp("head"), config)?;
        Ok(Self {
            model,
            head,
            classifier,
            classifier_pooling: config
                .classifier_config
                .as_ref()
                .map(|cc| cc.classifier_pooling)
                .unwrap_or_default(),
        })
    }

    pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result<Tensor> {
        let output = self.model.forward(xs, mask)?;
        let last_hidden_state = match self.classifier_pooling {
            ClassifierPooling::CLS => output.i((.., 0))?,
            ClassifierPooling::MEAN => {
                let unsqueezed_mask = &mask.unsqueeze(D::Minus1)?.to_dtype(DType::F32)?;
                let sum_output = output.broadcast_mul(unsqueezed_mask)?.sum(1)?;
                sum_output.broadcast_div(&mask.sum_keepdim(1)?.to_dtype(DType::F32)?)?
            }
        };
        let xs = self
            .head
            .forward(&last_hidden_state)?
            .apply(&self.classifier)?;
        Ok(xs)
    }
}

@ -10,33 +10,133 @@ use crate::models::clip::div_l2_norm;
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{layer_norm, linear, LayerNorm, Linear, VarBuilder};

fn default_text_vocab_size() -> usize {
    32000
}

fn default_text_hidden_size() -> usize {
    768
}

fn default_text_intermediate_size() -> usize {
    3072
}

fn default_text_num_hidden_layers() -> usize {
    12
}

fn default_text_num_attention_heads() -> usize {
    12
}

fn default_text_max_position_embeddings() -> usize {
    64
}

fn default_text_layer_norm_eps() -> f64 {
    1e-6
}

fn default_text_pad_token_id() -> u32 {
    1
}

fn default_text_bos_token_id() -> u32 {
    49406
}

fn default_text_eos_token_id() -> u32 {
    49407
}

fn default_text_hidden_act() -> candle_nn::Activation {
    candle_nn::Activation::GeluPytorchTanh
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L27
#[derive(serde::Deserialize, Clone, Debug)]
pub struct TextConfig {
    #[serde(default = "default_text_vocab_size")]
    pub vocab_size: usize,
    #[serde(default = "default_text_hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "default_text_intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "default_text_num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "default_text_num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "default_text_max_position_embeddings")]
    pub max_position_embeddings: usize,
    #[serde(default = "default_text_hidden_act")]
    pub hidden_act: candle_nn::Activation,
    #[serde(default = "default_text_layer_norm_eps")]
    pub layer_norm_eps: f64,
    #[serde(default = "default_text_pad_token_id")]
    pub pad_token_id: u32,
    #[serde(default = "default_text_bos_token_id")]
    pub bos_token_id: u32,
    #[serde(default = "default_text_eos_token_id")]
    pub eos_token_id: u32,
}
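
// The `#[serde(default = "...")]` attributes above allow a partial config.json
// to deserialize cleanly: any missing field falls back to the value returned by
// the corresponding default_* function.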

fn default_vision_hidden_size() -> usize {
    768
}

fn default_vision_intermediate_size() -> usize {
    3072
}

fn default_vision_num_hidden_layers() -> usize {
    12
}

fn default_vision_num_attention_heads() -> usize {
    12
}

fn default_vision_num_channels() -> usize {
    3
}

fn default_vision_image_size() -> usize {
    224
}

fn default_vision_batch_size() -> usize {
    16
}

fn default_vision_layer_norm_eps() -> f64 {
    1e-6
}

fn default_vision_hidden_act() -> candle_nn::Activation {
    candle_nn::Activation::GeluPytorchTanh
}

// https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/models/siglip/configuration_siglip.py#L132
#[derive(serde::Deserialize, Clone, Debug)]
pub struct VisionConfig {
    #[serde(default = "default_vision_hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "default_vision_intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "default_vision_num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "default_vision_num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "default_vision_num_channels")]
    pub num_channels: usize,
    #[serde(default = "default_vision_image_size")]
    pub image_size: usize,
    #[serde(default = "default_vision_batch_size")]
    pub patch_size: usize,
    #[serde(default = "default_vision_hidden_act")]
    pub hidden_act: candle_nn::Activation,
    #[serde(default = "default_vision_layer_norm_eps")]
    pub layer_norm_eps: f64,
}

@ -3,7 +3,7 @@ use anyhow::Error as E;
use candle::{safetensors::Load, DType, Device, IndexOp, Tensor, D};
use candle_nn::{ops::softmax, VarBuilder};
pub use candle_transformers::models::whisper::{self as m, Config};
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use rand::{distr::Distribution, rngs::StdRng, SeedableRng};
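// rand 0.9 renamed the `distributions` module to `distr`; `WeightedIndex` now
// lives at `rand::distr::weighted::WeightedIndex`, hence the matching call-site
// change further down.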
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@ -221,7 +221,7 @@ impl Decoder {
        let next_token = if t > 0f64 {
            let prs = softmax(&(&logits / t)?, 0)?;
            let logits_v: Vec<f32> = prs.to_vec1()?;
            let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
            let distr = rand::distr::weighted::WeightedIndex::new(&logits_v)?;
            distr.sample(&mut self.rng) as u32
        } else {
            let logits_v: Vec<f32> = logits.to_vec1()?;