mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Compare commits
29 Commits
mkl_link_f
...
0.9.0-alph
Author | SHA1 | Date | |
---|---|---|---|
1d1d6d4fe6 | |||
2653002f29 | |||
a52b76ae82 | |||
fb660b8d43 | |||
2f9606b187 | |||
f3a73f80d1 | |||
b44d38de0e | |||
d9198deb37 | |||
15ed0b11ce | |||
34505fdf3a | |||
d7b7ce16e4 | |||
19fb6dac1f | |||
acc5bd335f | |||
eb478ece92 | |||
d339b01726 | |||
2f3bf42bcb | |||
e3370c6316 | |||
338f6a102e | |||
bc33df77e1 | |||
cf9d7bf24c | |||
9d31361c4f | |||
648596c073 | |||
d9904a3baf | |||
d6db305829 | |||
b4daa03e59 | |||
9541467d6b | |||
6429609090 | |||
ba473290da | |||
59c26195db |
40
.github/workflows/book-cd.yml
vendored
40
.github/workflows/book-cd.yml
vendored
@ -1,40 +0,0 @@
|
|||||||
name: Deploy Rust book
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches:
|
|
||||||
- main
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
deploy:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write # To push a branch
|
|
||||||
pull-requests: write # To create a PR from that branch
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
- name: Install latest mdbook
|
|
||||||
run: |
|
|
||||||
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
|
|
||||||
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
|
|
||||||
mkdir mdbook
|
|
||||||
curl -sSL $url | tar -xz --directory=./mdbook
|
|
||||||
echo `pwd`/mdbook >> $GITHUB_PATH
|
|
||||||
- name: Deploy GitHub Pages
|
|
||||||
run: |
|
|
||||||
# This assumes your book is in the root of your repository.
|
|
||||||
# Just add a `cd` here if you need to change to another directory.
|
|
||||||
cd candle-book
|
|
||||||
mdbook build
|
|
||||||
git worktree add gh-pages
|
|
||||||
git config user.name "Deploy from CI"
|
|
||||||
git config user.email ""
|
|
||||||
cd gh-pages
|
|
||||||
# Delete the ref to avoid keeping history.
|
|
||||||
git update-ref -d refs/heads/gh-pages
|
|
||||||
rm -rf *
|
|
||||||
mv ../book/* .
|
|
||||||
git add .
|
|
||||||
git commit -m "Deploy $GITHUB_SHA to gh-pages"
|
|
||||||
git push --force --set-upstream origin gh-pages
|
|
29
.github/workflows/book.yml
vendored
29
.github/workflows/book.yml
vendored
@ -1,29 +0,0 @@
|
|||||||
name: CI
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
name: Test candle-book
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write # To push a branch
|
|
||||||
pull-requests: write # To create a PR from that branch
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@master
|
|
||||||
- name: Install Rust
|
|
||||||
run: |
|
|
||||||
rustup set profile minimal
|
|
||||||
rustup toolchain install stable
|
|
||||||
rustup default stable
|
|
||||||
- name: Install latest mdbook
|
|
||||||
run: |
|
|
||||||
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
|
|
||||||
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
|
|
||||||
mkdir bin
|
|
||||||
curl -sSL $url | tar -xz --directory=bin
|
|
||||||
echo "$(pwd)/bin" >> $GITHUB_PATH
|
|
||||||
- name: Run tests
|
|
||||||
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
|
|
||||||
|
|
||||||
|
|
28
Cargo.toml
28
Cargo.toml
@ -3,7 +3,6 @@ members = [
|
|||||||
"candle-core",
|
"candle-core",
|
||||||
"candle-datasets",
|
"candle-datasets",
|
||||||
"candle-examples",
|
"candle-examples",
|
||||||
"candle-book",
|
|
||||||
"candle-nn",
|
"candle-nn",
|
||||||
"candle-pyo3",
|
"candle-pyo3",
|
||||||
"candle-transformers",
|
"candle-transformers",
|
||||||
@ -12,6 +11,7 @@ members = [
|
|||||||
"tensor-tools",
|
"tensor-tools",
|
||||||
]
|
]
|
||||||
exclude = [
|
exclude = [
|
||||||
|
"candle-book",
|
||||||
"candle-flash-attn",
|
"candle-flash-attn",
|
||||||
"candle-kernels",
|
"candle-kernels",
|
||||||
"candle-metal-kernels",
|
"candle-metal-kernels",
|
||||||
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.8.4" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.8.4" }
|
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.3" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.8.4" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.3" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.8.4" }
|
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.3" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.8.4" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.3" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.8.4" }
|
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.3" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.8.4" }
|
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.3" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.8.4" }
|
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.3" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.13.5", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
fancy-regex = "0.13.0"
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.4.1"
|
hf-hub = "0.4.1"
|
||||||
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
ug = "0.1.0"
|
ug = "0.3.1"
|
||||||
ug-cuda = "0.1.0"
|
ug-cuda = "0.3.1"
|
||||||
ug-metal = "0.1.0"
|
ug-metal = "0.3.1"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "1.1.1", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
@ -56,3 +56,7 @@ harness = false
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "metal_basics"
|
name = "metal_basics"
|
||||||
required-features = ["metal"]
|
required-features = ["metal"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "cuda_basics"
|
||||||
|
required-features = ["cuda"]
|
||||||
|
@ -21,7 +21,9 @@ impl BenchDevice for Device {
|
|||||||
Device::Cpu => Ok(()),
|
Device::Cpu => Ok(()),
|
||||||
Device::Cuda(device) => {
|
Device::Cuda(device) => {
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
return Ok(device.synchronize()?);
|
return Ok(device
|
||||||
|
.synchronize()
|
||||||
|
.map_err(|e| candle_core::Error::Cuda(Box::new(e)))?);
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(not(feature = "cuda"))]
|
||||||
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
panic!("Cuda device without cuda feature enabled: {:?}", device)
|
||||||
}
|
}
|
||||||
|
@ -6,28 +6,18 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
let device = Device::new_cuda(0)?;
|
||||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
|
||||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||||
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
|
|
||||||
let _x1 = x.matmul(&x)?;
|
|
||||||
drop(_x1);
|
|
||||||
let start_time = std::time::Instant::now();
|
|
||||||
let _x1 = x.matmul(&x)?;
|
|
||||||
device.synchronize()?;
|
|
||||||
println!("fp32: {:?}", start_time.elapsed());
|
|
||||||
drop(_x1);
|
|
||||||
candle_core::cuda::set_gemm_reduced_precision_f32(true);
|
|
||||||
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
|
|
||||||
let _x1 = x.matmul(&x)?;
|
|
||||||
drop(_x1);
|
|
||||||
let start_time = std::time::Instant::now();
|
|
||||||
let _x1 = x.matmul(&x)?;
|
|
||||||
device.synchronize()?;
|
|
||||||
println!("tf32: {:?}", start_time.elapsed());
|
|
||||||
drop(_x1);
|
drop(_x1);
|
||||||
|
for _ in 0..20 {
|
||||||
|
let start_time = std::time::Instant::now();
|
||||||
|
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
|
||||||
|
device.synchronize()?;
|
||||||
|
println!("conv1d: {:?}", start_time.elapsed());
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ pub struct ParamsConv1D {
|
|||||||
pub(crate) padding: usize,
|
pub(crate) padding: usize,
|
||||||
pub(crate) stride: usize,
|
pub(crate) stride: usize,
|
||||||
pub(crate) dilation: usize,
|
pub(crate) dilation: usize,
|
||||||
|
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ParamsConv1D {
|
impl ParamsConv1D {
|
||||||
@ -54,7 +55,7 @@ impl ParamsConvTranspose1D {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum CudnnFwdAlgo {
|
pub enum CudnnFwdAlgo {
|
||||||
ImplicitGemm,
|
ImplicitGemm,
|
||||||
ImplicitPrecompGemm,
|
ImplicitPrecompGemm,
|
||||||
@ -151,6 +152,19 @@ impl Tensor {
|
|||||||
stride: usize,
|
stride: usize,
|
||||||
dilation: usize,
|
dilation: usize,
|
||||||
groups: usize,
|
groups: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies a 1D convolution over the input tensor.
|
||||||
|
pub fn conv1d_with_algo(
|
||||||
|
&self,
|
||||||
|
kernel: &Self,
|
||||||
|
padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
|
groups: usize,
|
||||||
|
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||||
let (b_size, c_in, l_in) = self.dims3()?;
|
let (b_size, c_in, l_in) = self.dims3()?;
|
||||||
@ -174,6 +188,7 @@ impl Tensor {
|
|||||||
padding,
|
padding,
|
||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
|
cudnn_fwd_algo,
|
||||||
};
|
};
|
||||||
if groups == 1 {
|
if groups == 1 {
|
||||||
self.conv1d_single_group(kernel, &params)
|
self.conv1d_single_group(kernel, &params)
|
||||||
@ -278,6 +293,18 @@ impl Tensor {
|
|||||||
stride: usize,
|
stride: usize,
|
||||||
dilation: usize,
|
dilation: usize,
|
||||||
groups: usize,
|
groups: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn conv2d_with_algo(
|
||||||
|
&self,
|
||||||
|
kernel: &Self,
|
||||||
|
padding: usize,
|
||||||
|
stride: usize,
|
||||||
|
dilation: usize,
|
||||||
|
groups: usize,
|
||||||
|
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||||
@ -297,7 +324,7 @@ impl Tensor {
|
|||||||
padding,
|
padding,
|
||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
cudnn_fwd_algo: None,
|
cudnn_fwd_algo,
|
||||||
};
|
};
|
||||||
if groups == 1 {
|
if groups == 1 {
|
||||||
self.conv2d_single_group(kernel, &params)
|
self.conv2d_single_group(kernel, &params)
|
||||||
|
@ -1289,6 +1289,15 @@ impl Map2 for MatMul {
|
|||||||
} else {
|
} else {
|
||||||
Parallelism::None
|
Parallelism::None
|
||||||
};
|
};
|
||||||
|
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
|
||||||
|
// a_skip and c_skip should be updated but step is always 0 so
|
||||||
|
// it wouldn't matter.
|
||||||
|
(1, b * m, n, k)
|
||||||
|
} else if a_skip == 0 && b_skip == n * k {
|
||||||
|
(1, m, b * n, k)
|
||||||
|
} else {
|
||||||
|
(b, m, n, k)
|
||||||
|
};
|
||||||
for step in 0..b {
|
for step in 0..b {
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
@ -43,7 +43,7 @@ pub(crate) fn launch_conv2d<
|
|||||||
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||||
return Ok(cudnn.clone());
|
return Ok(cudnn.clone());
|
||||||
}
|
}
|
||||||
let c = Cudnn::new(dev.cuda_device());
|
let c = Cudnn::new(dev.cuda_stream());
|
||||||
if let Ok(c) = &c {
|
if let Ok(c) = &c {
|
||||||
cudnn.borrow_mut().insert(device_id, c.clone());
|
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||||
}
|
}
|
||||||
@ -109,7 +109,7 @@ pub(crate) fn launch_conv2d<
|
|||||||
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||||
};
|
};
|
||||||
let workspace_size = conv2d.get_workspace_size(alg)?;
|
let workspace_size = conv2d.get_workspace_size(alg)?;
|
||||||
let mut workspace = dev.cuda_device().alloc_zeros::<u8>(workspace_size)?;
|
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||||
unsafe {
|
unsafe {
|
||||||
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
conv2d.launch::<CudaSlice<u8>, _, _, _>(
|
||||||
alg,
|
alg,
|
||||||
@ -122,3 +122,104 @@ pub(crate) fn launch_conv2d<
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn launch_conv1d<
|
||||||
|
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
|
||||||
|
Y: cudarc::cudnn::CudnnDataType,
|
||||||
|
>(
|
||||||
|
src: &CudaView<T>,
|
||||||
|
src_l: &crate::Layout,
|
||||||
|
filter: &CudaView<T>,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
params: &crate::conv::ParamsConv1D,
|
||||||
|
dev: &crate::cuda_backend::CudaDevice,
|
||||||
|
) -> crate::Result<()> {
|
||||||
|
use crate::conv::CudnnFwdAlgo as CandleAlgo;
|
||||||
|
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
|
||||||
|
|
||||||
|
let device_id = dev.id();
|
||||||
|
let cudnn = CUDNN.with(|cudnn| {
|
||||||
|
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
|
||||||
|
return Ok(cudnn.clone());
|
||||||
|
}
|
||||||
|
let c = Cudnn::new(dev.cuda_stream());
|
||||||
|
if let Ok(c) = &c {
|
||||||
|
cudnn.borrow_mut().insert(device_id, c.clone());
|
||||||
|
}
|
||||||
|
c
|
||||||
|
})?;
|
||||||
|
let conv = cudnn.create_conv2d::<Y>(
|
||||||
|
/* pad */ [params.padding as i32, 0],
|
||||||
|
/* stride */ [params.stride as i32, 1],
|
||||||
|
/* dilation */ [params.dilation as i32, 1],
|
||||||
|
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
|
||||||
|
)?;
|
||||||
|
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
|
||||||
|
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
|
||||||
|
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
|
||||||
|
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
|
||||||
|
// > to 1.
|
||||||
|
let x_shape = [
|
||||||
|
params.b_size as i32,
|
||||||
|
params.c_in as i32,
|
||||||
|
params.l_in as i32,
|
||||||
|
1,
|
||||||
|
];
|
||||||
|
// Note that `src` already starts at the proper offset.
|
||||||
|
let x = if src_l.is_contiguous() {
|
||||||
|
cudnn.create_4d_tensor::<T>(
|
||||||
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
|
x_shape,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
let s = src_l.stride();
|
||||||
|
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
|
||||||
|
};
|
||||||
|
let w = cudnn.create_4d_filter::<T>(
|
||||||
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
|
[
|
||||||
|
params.c_out as i32,
|
||||||
|
params.c_in as i32,
|
||||||
|
params.k_size as i32,
|
||||||
|
1,
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
let l_out = params.l_out() as i32;
|
||||||
|
let y = cudnn.create_4d_tensor::<T>(
|
||||||
|
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
|
||||||
|
[params.b_size as i32, params.c_out as i32, l_out, 1],
|
||||||
|
)?;
|
||||||
|
let conv1d = ConvForward {
|
||||||
|
conv: &conv,
|
||||||
|
x: &x,
|
||||||
|
w: &w,
|
||||||
|
y: &y,
|
||||||
|
};
|
||||||
|
let alg = match params.cudnn_fwd_algo {
|
||||||
|
None => conv1d.pick_algorithm()?,
|
||||||
|
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
|
||||||
|
Some(CandleAlgo::ImplicitPrecompGemm) => {
|
||||||
|
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
|
||||||
|
}
|
||||||
|
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
|
||||||
|
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
|
||||||
|
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
|
||||||
|
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
|
||||||
|
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
|
||||||
|
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
|
||||||
|
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
|
||||||
|
};
|
||||||
|
let workspace_size = conv1d.get_workspace_size(alg)?;
|
||||||
|
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
|
||||||
|
unsafe {
|
||||||
|
conv1d.launch::<CudaSlice<u8>, _, _, _>(
|
||||||
|
alg,
|
||||||
|
Some(&mut workspace),
|
||||||
|
(T::one(), T::zero()),
|
||||||
|
src,
|
||||||
|
filter,
|
||||||
|
dst,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -2,8 +2,9 @@ use crate::backend::BackendDevice;
|
|||||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||||
pub use candle_kernels as kernels;
|
pub use candle_kernels as kernels;
|
||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
|
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
|
||||||
@ -24,10 +25,17 @@ impl DeviceId {
|
|||||||
struct CudaRng(cudarc::curand::CudaRng);
|
struct CudaRng(cudarc::curand::CudaRng);
|
||||||
unsafe impl Send for CudaRng {}
|
unsafe impl Send for CudaRng {}
|
||||||
|
|
||||||
|
pub struct ModuleStore {
|
||||||
|
mdls: [Option<Arc<cudarc::driver::CudaModule>>; kernels::ALL_IDS.len()],
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct CudaDevice {
|
pub struct CudaDevice {
|
||||||
id: DeviceId,
|
id: DeviceId,
|
||||||
device: Arc<cudarc::driver::CudaDevice>,
|
context: Arc<cudarc::driver::CudaContext>,
|
||||||
|
modules: Arc<std::sync::RwLock<ModuleStore>>,
|
||||||
|
custom_modules: Arc<std::sync::RwLock<HashMap<String, Arc<cudarc::driver::CudaModule>>>>,
|
||||||
|
stream: Arc<cudarc::driver::CudaStream>,
|
||||||
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
|
||||||
curand: Arc<Mutex<CudaRng>>,
|
curand: Arc<Mutex<CudaRng>>,
|
||||||
}
|
}
|
||||||
@ -38,17 +46,102 @@ impl std::fmt::Debug for CudaDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::ops::Deref for CudaDevice {
|
impl CudaDevice {
|
||||||
type Target = Arc<cudarc::driver::CudaDevice>;
|
#[allow(clippy::missing_safety_doc)]
|
||||||
|
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
|
||||||
|
&self,
|
||||||
|
len: usize,
|
||||||
|
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||||
|
self.stream.alloc::<T>(len).w()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
len: usize,
|
||||||
|
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||||
|
self.stream.alloc_zeros::<T>(len).w()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpy_htod<
|
||||||
|
T: cudarc::driver::DeviceRepr,
|
||||||
|
Src: cudarc::driver::HostSlice<T> + ?Sized,
|
||||||
|
Dst: cudarc::driver::DevicePtrMut<T>,
|
||||||
|
>(
|
||||||
|
&self,
|
||||||
|
src: &Src,
|
||||||
|
dst: &mut Dst,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.stream.memcpy_htod(src, dst).w()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
|
||||||
|
&self,
|
||||||
|
src: &Src,
|
||||||
|
) -> Result<Vec<T>> {
|
||||||
|
self.stream.memcpy_dtov(src).w()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpy_dtod<
|
||||||
|
T,
|
||||||
|
Src: cudarc::driver::DevicePtr<T>,
|
||||||
|
Dst: cudarc::driver::DevicePtrMut<T>,
|
||||||
|
>(
|
||||||
|
&self,
|
||||||
|
src: &Src,
|
||||||
|
dst: &mut Dst,
|
||||||
|
) -> Result<()> {
|
||||||
|
self.stream.memcpy_dtod(src, dst).w()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memcpy_stod<
|
||||||
|
T: cudarc::driver::DeviceRepr,
|
||||||
|
Src: cudarc::driver::HostSlice<T> + ?Sized,
|
||||||
|
>(
|
||||||
|
&self,
|
||||||
|
src: &Src,
|
||||||
|
) -> Result<cudarc::driver::CudaSlice<T>> {
|
||||||
|
self.stream.memcpy_stod(src).w()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CudaFunc {
|
||||||
|
func: CudaFunction,
|
||||||
|
stream: Arc<cudarc::driver::CudaStream>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for CudaFunc {
|
||||||
|
type Target = CudaFunction;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target {
|
||||||
&self.device
|
&self.func
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaFunc {
|
||||||
|
pub fn into_cuda_function(self) -> CudaFunction {
|
||||||
|
self.func
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! builder_arg {
|
||||||
|
($b:ident, $($arg:expr),*) => {
|
||||||
|
$(
|
||||||
|
let __arg = $arg;
|
||||||
|
$b.arg(&__arg);
|
||||||
|
)*
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CudaFunc {
|
||||||
|
pub fn builder(&self) -> cudarc::driver::LaunchArgs<'_> {
|
||||||
|
self.stream.launch_builder(&self.func)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CudaDevice {
|
impl CudaDevice {
|
||||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
pub fn cuda_stream(&self) -> Arc<cudarc::driver::CudaStream> {
|
||||||
self.device.clone()
|
self.stream.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(target_arch = "wasm32"))]
|
#[cfg(not(target_arch = "wasm32"))]
|
||||||
@ -56,7 +149,7 @@ impl CudaDevice {
|
|||||||
&self,
|
&self,
|
||||||
func_name: &'static str,
|
func_name: &'static str,
|
||||||
kernel: ug::lang::ssa::Kernel,
|
kernel: ug::lang::ssa::Kernel,
|
||||||
) -> Result<CudaFunction> {
|
) -> Result<CudaFunc> {
|
||||||
let mut buf = vec![];
|
let mut buf = vec![];
|
||||||
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||||
let cuda_code = String::from_utf8(buf)?;
|
let cuda_code = String::from_utf8(buf)?;
|
||||||
@ -65,12 +158,12 @@ impl CudaDevice {
|
|||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
||||||
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
|
let module = self.context.load_module(ptx).w()?;
|
||||||
let func = match self.device.get_func("ug", func_name) {
|
let func = module.load_function(func_name).w()?;
|
||||||
Some(func) => func,
|
Ok(CudaFunc {
|
||||||
None => crate::bail!("unknown function ug::{func_name}"),
|
func,
|
||||||
};
|
stream: self.stream.clone(),
|
||||||
Ok(func)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
pub fn id(&self) -> DeviceId {
|
||||||
@ -83,58 +176,85 @@ impl CudaDevice {
|
|||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
DType::U8 => {
|
DType::U8 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<u8>(elem_count)? };
|
||||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
||||||
let params = (&data, v as u8, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as u8;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
DType::U32 => {
|
DType::U32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<u32>(elem_count)? };
|
||||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
||||||
let params = (&data, v as u32, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as u32;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
DType::I64 => {
|
DType::I64 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<i64>(elem_count)? };
|
||||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
||||||
let params = (&data, v as i64, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as i64;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
DType::BF16 => {
|
DType::BF16 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<bf16>(elem_count)? };
|
||||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
||||||
let params = (&data, bf16::from_f64(v), elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = bf16::from_f64(v);
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
DType::F16 => {
|
DType::F16 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<f16>(elem_count)? };
|
||||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
||||||
let params = (&data, f16::from_f64(v), elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = f16::from_f64(v);
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
||||||
let params = (&data, v as f32, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let v = v as f32;
|
||||||
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
// SAFETY: Set later by running the fill kernel.
|
// SAFETY: Set later by running the fill kernel.
|
||||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
||||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
||||||
let params = (&data, v, elem_count);
|
let mut builder = self.stream.launch_builder(&func);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data);
|
||||||
|
builder.arg(&v);
|
||||||
|
builder.arg(&elem_count);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -144,38 +264,69 @@ impl CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
pub fn get_or_load_custom_func(
|
||||||
if !self.has_func(module_name, module_name) {
|
&self,
|
||||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
fn_name: &str,
|
||||||
// done once per kernel name.
|
module_name: &str,
|
||||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
ptx: &str,
|
||||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
) -> Result<CudaFunc> {
|
||||||
.map_err(|cuda| CudaError::Load {
|
let ms = self.custom_modules.read().unwrap();
|
||||||
cuda,
|
if let Some(mdl) = ms.get(module_name).as_ref() {
|
||||||
module_name: module_name.to_string(),
|
let func = mdl.load_function(fn_name).w()?;
|
||||||
})
|
return Ok(CudaFunc {
|
||||||
.w()?;
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
self.get_func(module_name, module_name)
|
drop(ms);
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
let mut ms = self.custom_modules.write().unwrap();
|
||||||
// able to only build the error value if needed.
|
let cuda_module = self.context.load_module(ptx.into()).w()?;
|
||||||
.ok_or(CudaError::MissingKernel {
|
ms.insert(module_name.to_string(), cuda_module.clone());
|
||||||
module_name: module_name.to_string(),
|
let func = cuda_module.load_function(fn_name).w()?;
|
||||||
})
|
Ok(CudaFunc {
|
||||||
.w()
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_or_load_func(&self, fn_name: &str, mdl: &kernels::Module) -> Result<CudaFunc> {
|
||||||
|
let ms = self.modules.read().unwrap();
|
||||||
|
if let Some(mdl) = ms.mdls[mdl.index()].as_ref() {
|
||||||
|
let func = mdl.load_function(fn_name).w()?;
|
||||||
|
return Ok(CudaFunc {
|
||||||
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
drop(ms);
|
||||||
|
let mut ms = self.modules.write().unwrap();
|
||||||
|
let cuda_module = self.context.load_module(mdl.ptx().into()).w()?;
|
||||||
|
ms.mdls[mdl.index()] = Some(cuda_module.clone());
|
||||||
|
let func = cuda_module.load_function(fn_name).w()?;
|
||||||
|
Ok(CudaFunc {
|
||||||
|
func,
|
||||||
|
stream: self.stream.clone(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CudaDevice {
|
impl CudaDevice {
|
||||||
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
pub fn new_with_stream(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?;
|
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
let stream = context.new_stream().w()?;
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||||
|
let module_store = ModuleStore {
|
||||||
|
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: DeviceId::new(),
|
id: DeviceId::new(),
|
||||||
device,
|
context,
|
||||||
|
stream,
|
||||||
blas: Arc::new(blas),
|
blas: Arc::new(blas),
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||||
|
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -184,14 +335,21 @@ impl BackendDevice for CudaDevice {
|
|||||||
type Storage = CudaStorage;
|
type Storage = CudaStorage;
|
||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
let context = cudarc::driver::CudaContext::new(ordinal).w()?;
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
let stream = context.default_stream();
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
let blas = cudarc::cublas::CudaBlas::new(stream.clone()).w()?;
|
||||||
|
let curand = cudarc::curand::CudaRng::new(299792458, stream.clone()).w()?;
|
||||||
|
let module_store = ModuleStore {
|
||||||
|
mdls: [const { None }; kernels::ALL_IDS.len()],
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
id: DeviceId::new(),
|
id: DeviceId::new(),
|
||||||
device,
|
context,
|
||||||
|
stream,
|
||||||
blas: Arc::new(blas),
|
blas: Arc::new(blas),
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
||||||
|
modules: Arc::new(std::sync::RwLock::new(module_store)),
|
||||||
|
custom_modules: Arc::new(std::sync::RwLock::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,13 +357,13 @@ impl BackendDevice for CudaDevice {
|
|||||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
// We do not call set_seed but instead create a new curand object. This ensures that the
|
||||||
// state will be identical and the same random numbers will be generated.
|
// state will be identical and the same random numbers will be generated.
|
||||||
let mut curand = self.curand.lock().unwrap();
|
let mut curand = self.curand.lock().unwrap();
|
||||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
curand.0 = cudarc::curand::CudaRng::new(seed, self.stream.clone()).w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
crate::DeviceLocation::Cuda {
|
crate::DeviceLocation::Cuda {
|
||||||
gpu_id: self.device.ordinal(),
|
gpu_id: self.context.ordinal(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,31 +375,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
DType::U8 => {
|
DType::U8 => {
|
||||||
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
let data = self.alloc_zeros::<u8>(elem_count)?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
DType::U32 => {
|
DType::U32 => {
|
||||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
let data = self.alloc_zeros::<u32>(elem_count)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
DType::I64 => {
|
DType::I64 => {
|
||||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
let data = self.alloc_zeros::<i64>(elem_count)?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
DType::BF16 => {
|
DType::BF16 => {
|
||||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
let data = self.alloc_zeros::<bf16>(elem_count)?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
DType::F16 => {
|
DType::F16 => {
|
||||||
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
let data = self.alloc_zeros::<f16>(elem_count)?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
let data = self.alloc_zeros::<f32>(elem_count)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
let data = self.alloc_zeros::<f64>(elem_count)?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -265,12 +423,12 @@ impl BackendDevice for CudaDevice {
|
|||||||
.w()?
|
.w()?
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
|
||||||
curand.0.fill_with_uniform(&mut data).w()?;
|
curand.0.fill_with_uniform(&mut data).w()?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
|
||||||
curand.0.fill_with_uniform(&mut data).w()?;
|
curand.0.fill_with_uniform(&mut data).w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
@ -309,7 +467,7 @@ impl BackendDevice for CudaDevice {
|
|||||||
.w()?
|
.w()?
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
|
||||||
curand
|
curand
|
||||||
.0
|
.0
|
||||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
.fill_with_normal(&mut data, mean as f32, std as f32)
|
||||||
@ -317,7 +475,7 @@ impl BackendDevice for CudaDevice {
|
|||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
|
||||||
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
@ -336,31 +494,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
DType::U8 => {
|
DType::U8 => {
|
||||||
let data = self.alloc::<u8>(elem_count).w()?;
|
let data = self.alloc::<u8>(elem_count)?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
DType::U32 => {
|
DType::U32 => {
|
||||||
let data = self.alloc::<u32>(elem_count).w()?;
|
let data = self.alloc::<u32>(elem_count)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
DType::I64 => {
|
DType::I64 => {
|
||||||
let data = self.alloc::<i64>(elem_count).w()?;
|
let data = self.alloc::<i64>(elem_count)?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
DType::BF16 => {
|
DType::BF16 => {
|
||||||
let data = self.alloc::<bf16>(elem_count).w()?;
|
let data = self.alloc::<bf16>(elem_count)?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
DType::F16 => {
|
DType::F16 => {
|
||||||
let data = self.alloc::<f16>(elem_count).w()?;
|
let data = self.alloc::<f16>(elem_count)?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let data = self.alloc::<f32>(elem_count).w()?;
|
let data = self.alloc::<f32>(elem_count)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let data = self.alloc::<f64>(elem_count).w()?;
|
let data = self.alloc::<f64>(elem_count)?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -373,31 +531,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
let slice = match T::cpu_storage_ref(s) {
|
let slice = match T::cpu_storage_ref(s) {
|
||||||
CpuStorageRef::U8(storage) => {
|
CpuStorageRef::U8(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::U32(storage) => {
|
CpuStorageRef::U32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::I64(storage) => {
|
CpuStorageRef::I64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::BF16(storage) => {
|
CpuStorageRef::BF16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::F16(storage) => {
|
CpuStorageRef::F16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::F32(storage) => {
|
CpuStorageRef::F32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorageRef::F64(storage) => {
|
CpuStorageRef::F64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -410,31 +568,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
|
||||||
let slice = match storage {
|
let slice = match storage {
|
||||||
CpuStorage::U8(storage) => {
|
CpuStorage::U8(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
CpuStorage::U32(storage) => {
|
CpuStorage::U32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::I64(storage) => {
|
CpuStorage::I64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
CpuStorage::BF16(storage) => {
|
CpuStorage::BF16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F16(storage) => {
|
CpuStorage::F16(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F64(storage) => {
|
CpuStorage::F64(storage) => {
|
||||||
let data = self.htod_sync_copy(storage).w()?;
|
let data = self.memcpy_stod(storage)?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -447,31 +605,31 @@ impl BackendDevice for CudaDevice {
|
|||||||
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
|
||||||
let slice = match storage {
|
let slice = match storage {
|
||||||
CpuStorage::U8(storage) => {
|
CpuStorage::U8(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::U8(data)
|
CudaStorageSlice::U8(data)
|
||||||
}
|
}
|
||||||
CpuStorage::U32(storage) => {
|
CpuStorage::U32(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::U32(data)
|
CudaStorageSlice::U32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::I64(storage) => {
|
CpuStorage::I64(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::I64(data)
|
CudaStorageSlice::I64(data)
|
||||||
}
|
}
|
||||||
CpuStorage::BF16(storage) => {
|
CpuStorage::BF16(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::BF16(data)
|
CudaStorageSlice::BF16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F16(storage) => {
|
CpuStorage::F16(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::F16(data)
|
CudaStorageSlice::F16(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F32(storage) => {
|
CpuStorage::F32(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::F32(data)
|
CudaStorageSlice::F32(data)
|
||||||
}
|
}
|
||||||
CpuStorage::F64(storage) => {
|
CpuStorage::F64(storage) => {
|
||||||
let data = self.htod_copy(storage).w()?;
|
let data = self.memcpy_stod(&storage)?;
|
||||||
CudaStorageSlice::F64(data)
|
CudaStorageSlice::F64(data)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -482,7 +640,7 @@ impl BackendDevice for CudaDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn synchronize(&self) -> Result<()> {
|
fn synchronize(&self) -> Result<()> {
|
||||||
self.device.synchronize().map_err(crate::Error::wrap)?;
|
self.stream.synchronize().map_err(crate::Error::wrap)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -396,7 +396,10 @@ impl UgIOp1 {
|
|||||||
{
|
{
|
||||||
let device = device.as_cuda_device()?;
|
let device = device.as_cuda_device()?;
|
||||||
let func = device.compile(name, kernel)?;
|
let func = device.compile(name, kernel)?;
|
||||||
Ok(Self { name, func })
|
Ok(Self {
|
||||||
|
name,
|
||||||
|
func: func.into_cuda_function(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
#[cfg(feature = "metal")]
|
#[cfg(feature = "metal")]
|
||||||
{
|
{
|
||||||
@ -459,16 +462,16 @@ impl InplaceOp1 for UgIOp1 {
|
|||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
||||||
use crate::cuda_backend::WrapErr;
|
use crate::cuda_backend::WrapErr;
|
||||||
use cudarc::driver::LaunchAsync;
|
use cudarc::driver::PushKernelArg;
|
||||||
|
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
|
let stream = sto.device.cuda_stream();
|
||||||
// TODO: support more dtypes.
|
// TODO: support more dtypes.
|
||||||
let sto = sto.as_cuda_slice::<f32>()?;
|
let sto = sto.as_cuda_slice::<f32>()?;
|
||||||
let sto = match layout.contiguous_offsets() {
|
let sto = match layout.contiguous_offsets() {
|
||||||
None => crate::bail!("input has to be contiguous"),
|
None => crate::bail!("input has to be contiguous"),
|
||||||
Some((o1, o2)) => sto.slice(o1..o2),
|
Some((o1, o2)) => sto.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let params = (&sto,);
|
|
||||||
let (g, b) = if elem_count % 32 == 0 {
|
let (g, b) = if elem_count % 32 == 0 {
|
||||||
(elem_count / 32, 32)
|
(elem_count / 32, 32)
|
||||||
} else {
|
} else {
|
||||||
@ -479,7 +482,9 @@ impl InplaceOp1 for UgIOp1 {
|
|||||||
block_dim: (b as u32, 1, 1),
|
block_dim: (b as u32, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
unsafe { self.func.clone().launch(cfg, params) }.w()?;
|
let mut builder = stream.launch_builder(&self.func);
|
||||||
|
builder.arg(&sto);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -816,7 +816,7 @@ impl PthTensors {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `path` - Path to the pth file.
|
/// * `path` - Path to the pth file.
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||||
path: P,
|
path: P,
|
||||||
key: Option<&str>,
|
key: Option<&str>,
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
use super::{GgmlDType, QStorage};
|
use super::{GgmlDType, QStorage};
|
||||||
use crate::quantized::k_quants::GgmlType;
|
use crate::quantized::k_quants::GgmlType;
|
||||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||||
use crate::{CudaDevice, CudaStorage, Result};
|
use crate::{builder_arg as barg, CudaDevice, CudaStorage, Result};
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
|
||||||
use cudarc::driver::{CudaSlice, CudaView, DeviceSlice};
|
use cudarc::driver::{CudaSlice, CudaView, PushKernelArg};
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
struct PaddedCudaSlice {
|
struct PaddedCudaSlice {
|
||||||
@ -50,19 +50,20 @@ fn quantize_q8_1(
|
|||||||
ky: usize,
|
ky: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let kx = elem_count;
|
let kx = elem_count;
|
||||||
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
let kx_padded = pad(kx, MATRIX_ROW_PADDING);
|
||||||
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
let num_blocks = ceil_div(kx_padded, CUDA_QUANTIZE_BLOCK_SIZE);
|
||||||
let func = dev.get_or_load_func("quantize_q8_1", candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func("quantize_q8_1", &candle_kernels::QUANTIZED)?;
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (num_blocks as u32, ky as u32, 1),
|
grid_dim: (num_blocks as u32, ky as u32, 1),
|
||||||
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
block_dim: (CUDA_QUANTIZE_BLOCK_SIZE as u32, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let params = (src, dst, kx as i32, kx_padded as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(src);
|
||||||
|
builder.arg(dst);
|
||||||
|
barg!(builder, kx as i32, kx_padded as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -72,9 +73,7 @@ fn dequantize_f32(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
let nb = elem_count.div_ceil(256);
|
||||||
|
|
||||||
let nb = (elem_count + 255) / 256;
|
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||||
@ -99,8 +98,8 @@ fn dequantize_f32(
|
|||||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f32", true, 32, nb),
|
||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
|
||||||
// See e.g.
|
// See e.g.
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
@ -110,15 +109,20 @@ fn dequantize_f32(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if is_k {
|
if is_k {
|
||||||
let params = (&data.inner, &dst);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
} else {
|
} else {
|
||||||
let nb32 = match dtype {
|
let nb32 = match dtype {
|
||||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
_ => elem_count / 32,
|
_ => elem_count / 32,
|
||||||
};
|
};
|
||||||
let params = (&data.inner, &dst, nb32 as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, nb32 as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
}
|
}
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
@ -129,9 +133,7 @@ fn dequantize_f16(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
let nb = elem_count.div_ceil(256);
|
||||||
|
|
||||||
let nb = (elem_count + 255) / 256;
|
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||||
@ -156,8 +158,8 @@ fn dequantize_f16(
|
|||||||
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
GgmlDType::Q8K => ("dequantize_block_q8_K_f16", true, 32, nb),
|
||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
|
||||||
// See e.g.
|
// See e.g.
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
@ -167,15 +169,20 @@ fn dequantize_f16(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if is_k {
|
if is_k {
|
||||||
let params = (&data.inner, &dst);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
} else {
|
} else {
|
||||||
let nb32 = match dtype {
|
let nb32 = match dtype {
|
||||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||||
_ => elem_count / 32,
|
_ => elem_count / 32,
|
||||||
};
|
};
|
||||||
let params = (&data.inner, &dst, nb32 as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, nb32 as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
}
|
}
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
@ -188,8 +195,6 @@ fn dequantize_mul_mat_vec(
|
|||||||
nrows: usize,
|
nrows: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||||
if data_elems < ncols * nrows {
|
if data_elems < ncols * nrows {
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
@ -210,8 +215,8 @@ fn dequantize_mul_mat_vec(
|
|||||||
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
let dst = unsafe { dev.alloc::<f32>(nrows)? };
|
||||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (block_num_y as u32, 1, 1),
|
grid_dim: (block_num_y as u32, 1, 1),
|
||||||
@ -219,8 +224,12 @@ fn dequantize_mul_mat_vec(
|
|||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = (&data.inner, y, &dst, ncols as i32, nrows as i32);
|
let mut builder = func.builder();
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
builder.arg(&data.inner);
|
||||||
|
builder.arg(y);
|
||||||
|
builder.arg(&dst);
|
||||||
|
barg!(builder, ncols as i32, nrows as i32);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,8 +242,6 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
b_size: usize,
|
b_size: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||||
if data_elems < ncols * nrows {
|
if data_elems < ncols * nrows {
|
||||||
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
crate::bail!("unexpected data size {}, ncols {ncols} {nrows}", data_elems)
|
||||||
@ -249,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||||
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||||
|
|
||||||
let kernel_name = match dtype {
|
let kernel_name = match dtype {
|
||||||
@ -266,13 +273,13 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let kernel_name = format!("{kernel_name}{b_size}");
|
let kernel_name = format!("{kernel_name}{b_size}");
|
||||||
let func = dev.get_or_load_func(&kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||||
let (nblocks, nwarps) = match b_size {
|
let (nblocks, nwarps) = match b_size {
|
||||||
1 => (nrows as u32, 4),
|
1 => (nrows as u32, 4),
|
||||||
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
2..=4 => ((nrows as u32).div_ceil(2), 4),
|
||||||
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
5..=8 => ((nrows as u32).div_ceil(2), 2),
|
||||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||||
};
|
};
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
@ -281,16 +288,18 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&data.inner,
|
builder.arg(&data.inner);
|
||||||
&y_q8_1,
|
builder.arg(&y_q8_1);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
|
barg!(
|
||||||
|
builder,
|
||||||
/* ncols_x */ ncols as i32,
|
/* ncols_x */ ncols as i32,
|
||||||
/* nrows_x */ nrows as i32,
|
/* nrows_x */ nrows as i32,
|
||||||
/* nrows_y */ ncols_padded as i32,
|
/* nrows_y */ ncols_padded as i32,
|
||||||
/* nrows_dst */ nrows as i32,
|
/* nrows_dst */ nrows as i32
|
||||||
);
|
);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,8 +314,6 @@ fn mul_mat_via_q8_1(
|
|||||||
y_cols: usize,
|
y_cols: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
use cudarc::driver::LaunchAsync;
|
|
||||||
|
|
||||||
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
let data_elems = data.len / dtype.type_size() * dtype.block_size();
|
||||||
if data_elems < x_rows * x_cols {
|
if data_elems < x_rows * x_cols {
|
||||||
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
crate::bail!("unexpected lhs size {}, {x_rows} {x_cols}", data_elems)
|
||||||
@ -322,7 +329,7 @@ fn mul_mat_via_q8_1(
|
|||||||
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||||
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
||||||
|
|
||||||
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
||||||
@ -338,8 +345,8 @@ fn mul_mat_via_q8_1(
|
|||||||
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
GgmlDType::Q6K => ("mul_mat_q6_K", 64, 64),
|
||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (
|
grid_dim: (
|
||||||
ceil_div(x_rows, mmq_y) as u32,
|
ceil_div(x_rows, mmq_y) as u32,
|
||||||
@ -350,17 +357,19 @@ fn mul_mat_via_q8_1(
|
|||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
/* vx */ &data.inner,
|
builder.arg(/* vx */ &data.inner);
|
||||||
/* vy */ &y_q8_1,
|
builder.arg(/* vy */ &y_q8_1);
|
||||||
/* dst */ &dst,
|
builder.arg(/* dst */ &dst);
|
||||||
|
barg!(
|
||||||
|
builder,
|
||||||
/* ncols_x */ x_cols as i32,
|
/* ncols_x */ x_cols as i32,
|
||||||
/* nrows_x */ x_rows as i32,
|
/* nrows_x */ x_rows as i32,
|
||||||
/* ncols_y */ y_cols as i32,
|
/* ncols_y */ y_cols as i32,
|
||||||
/* nrows_y */ k_padded as i32,
|
/* nrows_y */ k_padded as i32,
|
||||||
/* nrows_dst */ x_rows as i32,
|
/* nrows_dst */ x_rows as i32
|
||||||
);
|
);
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,7 +378,7 @@ impl QCudaStorage {
|
|||||||
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
||||||
let padded_size_in_bytes =
|
let padded_size_in_bytes =
|
||||||
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
|
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
|
||||||
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
|
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
|
||||||
Ok(QCudaStorage {
|
Ok(QCudaStorage {
|
||||||
data: PaddedCudaSlice {
|
data: PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
@ -416,8 +425,7 @@ impl QCudaStorage {
|
|||||||
|
|
||||||
let buffer = self
|
let buffer = self
|
||||||
.device
|
.device
|
||||||
.dtoh_sync_copy(&self.data.inner.slice(..self.data.len))
|
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
|
||||||
.w()?;
|
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
let block_len = elem_count / self.dtype.block_size();
|
let block_len = elem_count / self.dtype.block_size();
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
@ -448,9 +456,7 @@ impl QCudaStorage {
|
|||||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||||
// Run the quantization on cpu.
|
// Run the quantization on cpu.
|
||||||
let src = match &src.slice {
|
let src = match &src.slice {
|
||||||
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
|
||||||
self.device.dtoh_sync_copy(data).w()?
|
|
||||||
}
|
|
||||||
_ => crate::bail!("only f32 can be quantized"),
|
_ => crate::bail!("only f32 can be quantized"),
|
||||||
};
|
};
|
||||||
let src_len = src.len();
|
let src_len = src.len();
|
||||||
@ -460,10 +466,9 @@ impl QCudaStorage {
|
|||||||
let data = qcpu_storage.data()?;
|
let data = qcpu_storage.data()?;
|
||||||
let padded_len =
|
let padded_len =
|
||||||
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
||||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
|
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
|
||||||
self.device
|
self.device
|
||||||
.htod_sync_copy_into(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
|
||||||
.w()?;
|
|
||||||
self.data = PaddedCudaSlice {
|
self.data = PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
len: data.len(),
|
len: data.len(),
|
||||||
@ -597,10 +602,8 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
};
|
};
|
||||||
let dtype = T::DTYPE;
|
let dtype = T::DTYPE;
|
||||||
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
||||||
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
|
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
|
||||||
device
|
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
|
||||||
.htod_sync_copy_into(data, &mut inner.slice_mut(..data.len()))
|
|
||||||
.w()?;
|
|
||||||
Ok(QStorage::Cuda(QCudaStorage {
|
Ok(QStorage::Cuda(QCudaStorage {
|
||||||
data: PaddedCudaSlice {
|
data: PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
@ -622,9 +625,9 @@ mod test {
|
|||||||
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
||||||
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs)?;
|
||||||
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -634,7 +637,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let ncols = 256;
|
let ncols = 256;
|
||||||
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs)?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_vec_via_q8_1(
|
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||||
@ -647,7 +650,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||||
assert_eq!(vs.len(), 1);
|
assert_eq!(vs.len(), 1);
|
||||||
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||||
// Q8 means 1/256 precision.
|
// Q8 means 1/256 precision.
|
||||||
@ -662,7 +665,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||||
assert_eq!(vs.len(), 1);
|
assert_eq!(vs.len(), 1);
|
||||||
assert_eq!(vs[0], 5561851.0);
|
assert_eq!(vs[0], 5561851.0);
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -673,7 +676,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let ncols = 256;
|
let ncols = 256;
|
||||||
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs)?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_via_q8_1(
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
@ -687,7 +690,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||||
@ -714,7 +717,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
||||||
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
||||||
let y = dev.htod_sync_copy(&vs).w()?;
|
let y = dev.memcpy_stod(&vs)?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_via_q8_1(
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
@ -728,7 +731,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
|
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ impl ArgSort {
|
|||||||
mod cuda {
|
mod cuda {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::cuda_backend::cudarc::driver::{
|
use crate::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaSlice, DeviceRepr, LaunchConfig, ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
use crate::cuda_backend::{kernel_name, kernels, CudaStorageSlice as S, WrapErr};
|
||||||
use crate::{CudaDevice, WithDType};
|
use crate::{CudaDevice, WithDType};
|
||||||
@ -69,27 +69,33 @@ mod cuda {
|
|||||||
layout: &crate::Layout,
|
layout: &crate::Layout,
|
||||||
_wrap: W,
|
_wrap: W,
|
||||||
) -> Result<S> {
|
) -> Result<S> {
|
||||||
|
use cudarc::driver::PushKernelArg;
|
||||||
|
|
||||||
let slice = match layout.contiguous_offsets() {
|
let slice = match layout.contiguous_offsets() {
|
||||||
None => crate::bail!("input has to be contiguous"),
|
None => crate::bail!("input has to be contiguous"),
|
||||||
Some((o1, o2)) => src.slice(o1..o2),
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
|
||||||
let func = if self.asc {
|
let func = if self.asc {
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), kernels::SORT)?
|
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
||||||
} else {
|
} else {
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), kernels::SORT)?
|
dev.get_or_load_func(&kernel_name::<T>("asort_desc"), &kernels::SORT)?
|
||||||
};
|
};
|
||||||
let ncols = self.last_dim;
|
let ncols = self.last_dim;
|
||||||
let nrows = elem_count / ncols;
|
let nrows = elem_count / ncols;
|
||||||
let ncols_pad = next_power_of_2(ncols);
|
let ncols_pad = next_power_of_2(ncols);
|
||||||
let params = (&slice, &dst, ncols as i32, ncols_pad as i32);
|
|
||||||
let cfg = LaunchConfig {
|
let cfg = LaunchConfig {
|
||||||
grid_dim: (1, nrows as u32, 1),
|
grid_dim: (1, nrows as u32, 1),
|
||||||
block_dim: (ncols_pad as u32, 1, 1),
|
block_dim: (ncols_pad as u32, 1, 1),
|
||||||
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
|
||||||
};
|
};
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let stream = dev.cuda_stream();
|
||||||
|
let mut builder = stream.launch_builder(&func);
|
||||||
|
let ncols = ncols as i32;
|
||||||
|
let ncols_pad = ncols_pad as i32;
|
||||||
|
builder.arg(&slice).arg(&dst).arg(&ncols).arg(&ncols_pad);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(S::U32(dst))
|
Ok(S::U32(dst))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2580,6 +2580,28 @@ impl Tensor {
|
|||||||
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
rhs.broadcast_mul(&self.log()?)?.exp()
|
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a new tensor with the order of elements reversed along the specified dimensions.
|
||||||
|
/// This function makes a copy of the tensor’s data.
|
||||||
|
///
|
||||||
|
/// ```rust
|
||||||
|
/// # use candle_core::{Tensor, Device};
|
||||||
|
/// let t = Tensor::arange(0., 6., &Device::Cpu)?.reshape((2, 3))?;
|
||||||
|
/// assert_eq!(t.to_vec2::<f64>()?, &[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
|
/// let t_flipped = t.flip(&[0])?;
|
||||||
|
/// assert_eq!(t_flipped.to_vec2::<f64>()?, &[[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn flip(&self, dims: &[usize]) -> Result<Tensor> {
|
||||||
|
let mut result = self.clone();
|
||||||
|
for &dim in dims.iter() {
|
||||||
|
let size = result.dim(dim)?;
|
||||||
|
let indices: Vec<i64> = (0..size).rev().map(|x| x as i64).collect();
|
||||||
|
let indices_tensor = Tensor::from_vec(indices, (size,), result.device())?;
|
||||||
|
result = result.index_select(&indices_tensor, dim)?;
|
||||||
|
}
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
@ -24,6 +24,15 @@ macro_rules! test_device {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
|
||||||
|
assert_eq!(t1.shape(), t2.shape());
|
||||||
|
// Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
|
||||||
|
let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
|
||||||
|
let all_equal = eq_tensor.sum_all()?;
|
||||||
|
assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
|
||||||
let b = 10f32.powi(digits);
|
let b = 10f32.powi(digits);
|
||||||
let t = t.to_vec0::<f32>()?;
|
let t = t.to_vec0::<f32>()?;
|
||||||
|
@ -53,6 +53,20 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||||
);
|
);
|
||||||
|
let res = {
|
||||||
|
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||||
|
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
|
||||||
|
};
|
||||||
|
assert_eq!(res.dims(), [3, 2, 5]);
|
||||||
|
// Same as pytorch default padding: use zeros.
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||||
|
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||||
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||||
|
);
|
||||||
|
|
||||||
let w = w.transpose(0, 1)?;
|
let w = w.transpose(0, 1)?;
|
||||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||||
@ -163,6 +177,22 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
let res = {
|
||||||
|
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
||||||
|
t.conv2d(&w, 0, 1, 1, 1)?
|
||||||
|
};
|
||||||
|
assert_eq!(res.dims(), [3, 2, 3, 3]);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
||||||
|
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
||||||
|
[
|
||||||
|
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
||||||
|
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#![allow(clippy::approx_constant)]
|
#![allow(clippy::approx_constant)]
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
use candle_core::{test_device, test_utils, DType, Device, Shape, Tensor, Var};
|
||||||
|
|
||||||
fn simple_grad(device: &Device) -> Result<()> {
|
fn simple_grad(device: &Device) -> Result<()> {
|
||||||
let x = Var::new(&[3f32, 1., 4.], device)?;
|
let x = Var::new(&[3f32, 1., 4.], device)?;
|
||||||
@ -505,6 +505,36 @@ fn binary_grad(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_backprop() -> Result<()> {
|
||||||
|
let device = &Device::Cpu;
|
||||||
|
|
||||||
|
// Create a tensor (leaf node) that requires gradients
|
||||||
|
let x = Var::ones((2, 2), DType::F64, device)?;
|
||||||
|
let weights = Tensor::arange(1.0, 5.0, device)?.reshape((2, 2))?;
|
||||||
|
|
||||||
|
let y = x.matmul(&weights)?;
|
||||||
|
let expected_y = Tensor::from_vec(vec![4.0, 6.0, 4.0, 6.0], (2, 2), device)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&y, &expected_y)?;
|
||||||
|
|
||||||
|
let z = y.flip(&[1])?;
|
||||||
|
let expected_z = Tensor::from_vec(vec![6.0, 4.0, 6.0, 4.0], (2, 2), device)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&z, &expected_z)?;
|
||||||
|
|
||||||
|
let loss = z.sum_all()?;
|
||||||
|
|
||||||
|
let grad_store = loss.backward()?;
|
||||||
|
let grad_x = grad_store.get_id(x.id()).unwrap();
|
||||||
|
|
||||||
|
let flipped_weights = weights.flip(&[1])?;
|
||||||
|
let dloss_dy = Tensor::ones((2, 2), DType::F64, device)?;
|
||||||
|
// dloss/dx = dloss/dy @ dy/dx = ones @ weight.flip.T
|
||||||
|
let expected_grad = dloss_dy.matmul(&flipped_weights.t()?)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(grad_x, &expected_grad)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
test_device!(
|
test_device!(
|
||||||
simple_grad,
|
simple_grad,
|
||||||
simple_grad_cpu,
|
simple_grad_cpu,
|
||||||
|
@ -1682,3 +1682,54 @@ fn pow() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_1d() -> Result<()> {
|
||||||
|
// 1D: [0, 1, 2, 3, 4]
|
||||||
|
let t = Tensor::arange(0.0, 5.0, &Device::Cpu)?.reshape((5,))?;
|
||||||
|
let flipped = t.flip(&[0])?;
|
||||||
|
// Expected: [4, 3, 2, 1, 0]
|
||||||
|
let expected = Tensor::from_vec(vec![4.0, 3.0, 2.0, 1.0, 0.0], (5,), &Device::Cpu)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_2d() -> Result<()> {
|
||||||
|
// 2D:
|
||||||
|
// [[0, 1, 2],
|
||||||
|
// [3, 4, 5]]
|
||||||
|
let t = Tensor::arange(0.0, 6.0, &Device::Cpu)?.reshape((2, 3))?;
|
||||||
|
let flipped = t.flip(&[0, 1])?;
|
||||||
|
// Expected:
|
||||||
|
// [[5, 4, 3],
|
||||||
|
// [2, 1, 0]]
|
||||||
|
let expected = Tensor::from_vec(vec![5.0, 4.0, 3.0, 2.0, 1.0, 0.0], (2, 3), &Device::Cpu)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flip_3d_channels() -> Result<()> {
|
||||||
|
// 3D:
|
||||||
|
// [[[0,1,2],
|
||||||
|
// [3,4,5]],
|
||||||
|
//
|
||||||
|
// [[6,7,8],
|
||||||
|
// [9,10,11]]]
|
||||||
|
let t = Tensor::arange(0.0, 12.0, &Device::Cpu)?.reshape((2, 2, 3))?;
|
||||||
|
let flipped = t.flip(&[2])?;
|
||||||
|
// Expected:
|
||||||
|
// [[[2,1,0],
|
||||||
|
// [5,4,3]],
|
||||||
|
//
|
||||||
|
// [[8,7,6],
|
||||||
|
// [11,10,9]]]
|
||||||
|
let expected = Tensor::from_vec(
|
||||||
|
vec![2.0, 1.0, 0.0, 5.0, 4.0, 3.0, 8.0, 7.0, 6.0, 11.0, 10.0, 9.0],
|
||||||
|
(2, 2, 3),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&flipped, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -72,6 +72,8 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
if let parquet::record::Field::Group(subrow) = field {
|
if let parquet::record::Field::Group(subrow) = field {
|
||||||
for (_name, field) in subrow.get_column_iter() {
|
for (_name, field) in subrow.get_column_iter() {
|
||||||
if let parquet::record::Field::Bytes(value) = field {
|
if let parquet::record::Field::Bytes(value) = field {
|
||||||
|
// image-rs crate convention is to load in (width, height, channels) order
|
||||||
|
// See: https://docs.rs/image/latest/image/trait.ImageDecoder.html#tymethod.dimensions
|
||||||
let image = image::load_from_memory(value.data()).unwrap();
|
let image = image::load_from_memory(value.data()).unwrap();
|
||||||
buffer_images.extend(image.to_rgb8().as_raw());
|
buffer_images.extend(image.to_rgb8().as_raw());
|
||||||
}
|
}
|
||||||
@ -81,8 +83,10 @@ fn load_parquet(parquet: SerializedFileReader<std::fs::File>) -> Result<(Tensor,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let images = (Tensor::from_vec(buffer_images, (samples, 3, 32, 32), &Device::Cpu)?
|
// Reorder image-rs convention (width, height, channels) to candle/pytorch convolution convention (channels, height, width)
|
||||||
.to_dtype(DType::U8)?
|
let images = (Tensor::from_vec(buffer_images, (samples, 32, 32, 3), &Device::Cpu)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.permute((0, 3, 2, 1))?
|
||||||
/ 255.)?;
|
/ 255.)?;
|
||||||
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
let labels = Tensor::from_vec(buffer_labels, (samples,), &Device::Cpu)?;
|
||||||
Ok((images, labels))
|
Ok((images, labels))
|
||||||
|
@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
|
|||||||
default = []
|
default = []
|
||||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||||
cudnn = ["candle/cudnn"]
|
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
|
||||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
@ -69,6 +69,7 @@ metal = ["candle/metal", "candle-nn/metal"]
|
|||||||
microphone = ["cpal", "rubato"]
|
microphone = ["cpal", "rubato"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
mimi = ["cpal", "symphonia", "rubato"]
|
mimi = ["cpal", "symphonia", "rubato"]
|
||||||
|
snac = ["cpal", "symphonia", "rubato"]
|
||||||
depth_anything_v2 = ["palette", "enterpolation"]
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
@ -107,6 +108,10 @@ required-features = ["candle-datasets"]
|
|||||||
name = "mimi"
|
name = "mimi"
|
||||||
required-features = ["mimi"]
|
required-features = ["mimi"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "snac"
|
||||||
|
required-features = ["snac"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["encodec"]
|
required-features = ["encodec"]
|
||||||
|
13
candle-examples/examples/chatglm/README.md
Normal file
13
candle-examples/examples/chatglm/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-chatglm
|
||||||
|
|
||||||
|
Uses `THUDM/chatglm3-6b` to generate Chinese text. It will usually not generate text for English prompts.
|
||||||
|
|
||||||
|
## Text Generation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example chatglm --release -- --prompt "部署门槛较低等众多优秀特 "
|
||||||
|
|
||||||
|
> 部署门槛较低等众多优秀特 点,使得其成为了一款备受欢迎的AI助手。
|
||||||
|
>
|
||||||
|
> 作为一款人工智能助手,ChatGLM3-6B
|
||||||
|
```
|
42
candle-examples/examples/chinese_clip/README.md
Normal file
42
candle-examples/examples/chinese_clip/README.md
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# candle-chinese-clip
|
||||||
|
|
||||||
|
Contrastive Language-Image Pre-Training (CLIP) is an architecture trained on
|
||||||
|
pairs of images with related texts. This one is trained on Chinese text instead of English.
|
||||||
|
|
||||||
|
## Running on cpu
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --cpu --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
|
||||||
|
|
||||||
|
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
|
||||||
|
>
|
||||||
|
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running on metal
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --features metal --example chinese_clip --release -- --images "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg","candle-examples/examples/yolo-v8/assets/bike.jpg" --sequences "一场自行车比赛","两只猫的照片","一个机器人拿着蜡烛"
|
||||||
|
|
||||||
|
> Results for image: candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325177Z INFO chinese_clip: Probability: 0.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325179Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325181Z INFO chinese_clip: Probability: 100.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
> 2025-03-25T19:22:01.325183Z INFO chinese_clip:
|
||||||
|
>
|
||||||
|
> Results for image: candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
>
|
||||||
|
> 2025-03-25T19:22:01.325184Z INFO chinese_clip: Probability: 100.0000% Text: 一场自行车比赛
|
||||||
|
> 2025-03-25T19:22:01.325186Z INFO chinese_clip: Probability: 0.0000% Text: 两只猫的照片
|
||||||
|
> 2025-03-25T19:22:01.325187Z INFO chinese_clip: Probability: 0.0000% Text: 一个机器人拿着蜡烛
|
||||||
|
```
|
17
candle-examples/examples/convmixer/README.md
Normal file
17
candle-examples/examples/convmixer/README.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# candle-convmixer
|
||||||
|
|
||||||
|
A lightweight CNN architecture that processes image patches similar to a vision transformer, with separate spatial and channel convolutions.
|
||||||
|
|
||||||
|
ConvMixer from [Patches Are All You Need?](https://arxiv.org/pdf/2201.09792) and [ConvMixer](https://github.com/locuslab/convmixer).
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example convmixer --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
> mountain bike, all-terrain bike, off-roader: 61.75%
|
||||||
|
> unicycle, monocycle : 5.73%
|
||||||
|
> moped : 3.66%
|
||||||
|
> bicycle-built-for-two, tandem bicycle, tandem: 3.51%
|
||||||
|
> crash helmet : 0.85%
|
||||||
|
```
|
14
candle-examples/examples/csm/README.md
Normal file
14
candle-examples/examples/csm/README.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Conversational Speech Model (CSM)
|
||||||
|
|
||||||
|
CSM is a speech generation model from Sesame,
|
||||||
|
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
|
||||||
|
|
||||||
|
It can generate conversational speech between two different speakers.
|
||||||
|
The speakers' turns are delimited by the `|` character in the prompt.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example csm --features cuda -r -- \
|
||||||
|
--voices candle-examples/examples/csm/voices.safetensors \
|
||||||
|
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
||||||
|
```
|
||||||
|
|
243
candle-examples/examples/csm/main.rs
Normal file
243
candle-examples/examples/csm/main.rs
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::csm::{Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, IndexOp, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "1b")]
|
||||||
|
Csm1b,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
use_flash_attn: bool,
|
||||||
|
|
||||||
|
/// The prompt to be used for the generation, use a | to separate the speakers.
|
||||||
|
#[arg(long, default_value = "Hey how are you doing today?")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The voices to be used, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
voices: String,
|
||||||
|
|
||||||
|
/// The output file using the wav format.
|
||||||
|
#[arg(long, default_value = "out.wav")]
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long, default_value_t = 0.7)]
|
||||||
|
temperature: f64,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "1b")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weights: Option<String>,
|
||||||
|
|
||||||
|
/// The mimi model weight file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
mimi_weights: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id,
|
||||||
|
None => {
|
||||||
|
let name = match args.which {
|
||||||
|
Which::Csm1b => "sesame/csm-1b",
|
||||||
|
};
|
||||||
|
name.to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let filenames = match args.weights {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => vec![repo.get("model.safetensors")?],
|
||||||
|
};
|
||||||
|
let tokenizer_filename = match args.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => api
|
||||||
|
.model("meta-llama/Llama-3.2-1B".to_string())
|
||||||
|
.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let mimi_filename = match args.mimi_weights {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("kyutai/mimi".to_string())
|
||||||
|
.get("model.safetensors")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config: Config = match args.config {
|
||||||
|
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||||
|
None => {
|
||||||
|
let config_file = repo.get("config.json")?;
|
||||||
|
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let (mut model, device) = {
|
||||||
|
let dtype = device.bf16_default_to_f32();
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
(model, device)
|
||||||
|
};
|
||||||
|
let mut mimi_model = {
|
||||||
|
use candle_transformers::models::mimi;
|
||||||
|
let vb =
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[mimi_filename], DType::F32, &device)? };
|
||||||
|
let config = mimi::Config::v0_1(Some(32));
|
||||||
|
mimi::Model::new(config, vb)?
|
||||||
|
};
|
||||||
|
let cb = config.audio_num_codebooks;
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let voices = candle::safetensors::load(args.voices, &device)?;
|
||||||
|
let mut lp = candle_transformers::generation::LogitsProcessor::new(
|
||||||
|
args.seed,
|
||||||
|
Some(args.temperature),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
let tokens = voices
|
||||||
|
.get("tokens")
|
||||||
|
.expect("no tokens in prompt")
|
||||||
|
.to_dtype(DType::U32)?;
|
||||||
|
let mask = voices.get("mask").expect("no mask in prompt").clone();
|
||||||
|
|
||||||
|
let mut pos = 0;
|
||||||
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
|
|
||||||
|
let mut all_pcms = vec![];
|
||||||
|
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
|
||||||
|
println!("{prompt:?}");
|
||||||
|
let speaker_idx = turn_idx % 2;
|
||||||
|
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
|
||||||
|
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
|
||||||
|
|
||||||
|
let mut generated_tokens = vec![];
|
||||||
|
loop {
|
||||||
|
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
|
let is_done = frame.iter().all(|&x| x == 0);
|
||||||
|
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
|
||||||
|
print!("\rframe {pos}");
|
||||||
|
if is_done {
|
||||||
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
generated_tokens.push(tokens.clone());
|
||||||
|
}
|
||||||
|
println!();
|
||||||
|
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||||
|
let pcm = mimi_model.decode(&generated_tokens)?;
|
||||||
|
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||||
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||||
|
all_pcms.push(pcm);
|
||||||
|
}
|
||||||
|
let pcm = Tensor::cat(&all_pcms, 0)?;
|
||||||
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
println!("writing output file {}", args.out_file);
|
||||||
|
let mut output = std::fs::File::create(args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
Binary file not shown.
17
candle-examples/examples/custom-ops/README.md
Normal file
17
candle-examples/examples/custom-ops/README.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# candle-custom-ops
|
||||||
|
|
||||||
|
This example illustrates how to implement forward and backward passes for custom operations on the CPU and GPU.
|
||||||
|
The custom op in this example implements RMS normalization for the CPU and CUDA.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example custom-ops
|
||||||
|
|
||||||
|
> [[ 0., 1., 2., 3., 4., 5., 6.],
|
||||||
|
> [ 7., 8., 9., 10., 11., 12., 13.]]
|
||||||
|
> Tensor[[2, 7], f32]
|
||||||
|
> [[0.0000, 0.2773, 0.5547, 0.8320, 1.1094, 1.3867, 1.6641],
|
||||||
|
> [0.6864, 0.7845, 0.8825, 0.9806, 1.0786, 1.1767, 1.2748]]
|
||||||
|
> Tensor[[2, 7], f32]
|
||||||
|
```
|
@ -56,7 +56,7 @@ impl CustomOp1 for LayerNorm {
|
|||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::backend::BackendStorage;
|
use candle::backend::BackendStorage;
|
||||||
use candle::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
|
use candle::cuda_backend::cudarc::driver::{LaunchConfig, PushKernelArg};
|
||||||
use candle::cuda_backend::WrapErr;
|
use candle::cuda_backend::WrapErr;
|
||||||
let (d1, d2) = layout.shape().dims2()?;
|
let (d1, d2) = layout.shape().dims2()?;
|
||||||
let d1 = d1 as u32;
|
let d1 = d1 as u32;
|
||||||
@ -68,15 +68,19 @@ impl CustomOp1 for LayerNorm {
|
|||||||
Some((o1, o2)) => slice.slice(o1..o2),
|
Some((o1, o2)) => slice.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
|
||||||
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
|
let func =
|
||||||
let params = (&dst, &slice, self.eps, d1, d2);
|
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||||
let cfg = LaunchConfig {
|
let cfg = LaunchConfig {
|
||||||
grid_dim: (d1, 1, 1),
|
grid_dim: (d1, 1, 1),
|
||||||
block_dim: (d2, 1, 1),
|
block_dim: (d2, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&dst);
|
||||||
|
builder.arg(&slice);
|
||||||
|
candle::builder_arg!(builder, self.eps, d1, d2);
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
|
|
||||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||||
Ok((dst, layout.shape().clone()))
|
Ok((dst, layout.shape().clone()))
|
||||||
|
@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model we
|
|||||||
are downloaded from the hub on the first run.
|
are downloaded from the hub on the first run.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||||
@ -20,3 +20,25 @@ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
|||||||
> Tensor[[1, 7, 768], f32]
|
> Tensor[[1, 7, 768], f32]
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Masked Token
|
||||||
|
|
||||||
|
DistilBert is used to compute the top K choices for a masked token.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
|
||||||
|
|
||||||
|
> Input: The capital of France is [MASK].
|
||||||
|
> Predictions for [MASK] at position 6:
|
||||||
|
> 1: marseille (probability: 12.14%)
|
||||||
|
> 2: paris (probability: 10.84%)
|
||||||
|
> 3: toulouse (probability: 8.57%)
|
||||||
|
> 4: lyon (probability: 7.61%)
|
||||||
|
> 5: montpellier (probability: 5.18%)
|
||||||
|
> 6: bordeaux (probability: 4.88%)
|
||||||
|
> 7: nantes (probability: 4.82%)
|
||||||
|
> 8: lille (probability: 4.07%)
|
||||||
|
> 9: strasbourg (probability: 3.12%)
|
||||||
|
> 10: cannes (probability: 3.04%)
|
||||||
|
|
||||||
|
```
|
@ -3,15 +3,48 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
use candle_transformers::models::distilbert::{
|
||||||
|
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Context, Error as E, Result};
|
||||||
use candle::{Device, Tensor};
|
use candle::{Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use std::path::PathBuf;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
enum ModelType {
|
||||||
|
Masked(DistilBertForMaskedLM),
|
||||||
|
UnMasked(DistilBertModel),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelType {
|
||||||
|
fn device(&self) -> &Device {
|
||||||
|
match self {
|
||||||
|
ModelType::Masked(model) => &model.bert.device,
|
||||||
|
ModelType::UnMasked(model) => &model.device,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||||
|
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "distilbert")]
|
||||||
|
DistilBert,
|
||||||
|
|
||||||
|
#[value(name = "distilbertformaskedlm")]
|
||||||
|
DistilbertForMaskedLM,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -23,10 +56,14 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "distilbert")]
|
||||||
|
model: Which,
|
||||||
|
|
||||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
/// Revision or branch
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
@ -42,94 +79,246 @@ struct Args {
|
|||||||
#[arg(long, default_value = "1")]
|
#[arg(long, default_value = "1")]
|
||||||
n: usize,
|
n: usize,
|
||||||
|
|
||||||
/// L2 normalization for embeddings.
|
/// Number of top predictions to show for each mask
|
||||||
#[arg(long, default_value = "true")]
|
#[arg(long, default_value = "5")]
|
||||||
normalize_embeddings: bool,
|
top_k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
|
||||||
let device = candle_examples::device(self.cpu)?;
|
let device = candle_examples::device(self.cpu)?;
|
||||||
|
|
||||||
|
let (model_id, revision) = self.resolve_model_and_revision();
|
||||||
|
let (config_path, tokenizer_path, weights_path) =
|
||||||
|
self.download_model_files(&model_id, &revision)?;
|
||||||
|
|
||||||
|
let config = std::fs::read_to_string(config_path)?;
|
||||||
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let vb = self.load_variables(&weights_path, &device)?;
|
||||||
|
let model = self.create_model(&config, vb)?;
|
||||||
|
|
||||||
|
Ok((model, tokenizer))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_model_and_revision(&self) -> (String, String) {
|
||||||
let default_model = "distilbert-base-uncased".to_string();
|
let default_model = "distilbert-base-uncased".to_string();
|
||||||
let default_revision = "main".to_string();
|
let default_revision = "main".to_string();
|
||||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
|
||||||
|
match (self.model_id.clone(), self.revision.clone()) {
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
(Some(model_id), None) => (model_id, default_revision),
|
||||||
(None, Some(revision)) => (default_model, revision),
|
(None, Some(revision)) => (default_model, revision),
|
||||||
(None, None) => (default_model, default_revision),
|
(None, None) => (default_model, default_revision),
|
||||||
};
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
fn download_model_files(
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
&self,
|
||||||
let api = Api::new()?;
|
model_id: &str,
|
||||||
let api = api.repo(repo);
|
revision: &str,
|
||||||
let config = api.get("config.json")?;
|
) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
||||||
let tokenizer = api.get("tokenizer.json")?;
|
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
|
||||||
let weights = if self.use_pth {
|
let api = Api::new()?;
|
||||||
api.get("pytorch_model.bin")?
|
let api = api.repo(repo);
|
||||||
} else {
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
};
|
|
||||||
(config, tokenizer, weights)
|
|
||||||
};
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let vb = if self.use_pth {
|
let config = api.get("config.json")?;
|
||||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
let tokenizer = api.get("tokenizer.json")?;
|
||||||
|
let weights = if self.use_pth {
|
||||||
|
api.get("pytorch_model.bin")?
|
||||||
} else {
|
} else {
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
api.get("model.safetensors")?
|
||||||
};
|
};
|
||||||
let model = DistilBertModel::load(vb, &config)?;
|
|
||||||
Ok((model, tokenizer))
|
Ok((config, tokenizer, weights))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
|
||||||
|
if self.use_pth {
|
||||||
|
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
|
||||||
|
} else {
|
||||||
|
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||||
|
match self.model {
|
||||||
|
Which::DistilbertForMaskedLM => {
|
||||||
|
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
|
||||||
|
}
|
||||||
|
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
fn main() -> Result<()> {
|
||||||
let mask: Vec<_> = (0..size)
|
let args = Args::parse();
|
||||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
let _guard = setup_tracing(&args);
|
||||||
.collect();
|
|
||||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
let (model, tokenizer) = args.build_model_and_tokenizer()?;
|
||||||
|
let device = model.device();
|
||||||
|
|
||||||
|
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
|
||||||
|
let output = model.forward(&token_ids, &mask)?;
|
||||||
|
|
||||||
|
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn setup_tracing(args: &Args) -> Option<impl Drop> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
if args.tracing {
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
println!("tracing...");
|
println!("tracing...");
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
Some(guard)
|
Some(guard)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
}
|
||||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
}
|
||||||
let device = &model.device;
|
|
||||||
|
|
||||||
let tokenizer = tokenizer
|
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
|
||||||
|
let mut binding = tokenizer.clone();
|
||||||
|
let tokenizer_configured = binding
|
||||||
.with_padding(None)
|
.with_padding(None)
|
||||||
.with_truncation(None)
|
.with_truncation(None)
|
||||||
.map_err(E::msg)?;
|
.map_err(E::msg)?;
|
||||||
let tokens = tokenizer
|
|
||||||
.encode(args.prompt, true)
|
let tokens = tokenizer_configured
|
||||||
|
.encode(args.prompt.clone(), true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
let mask = get_mask(tokens.len(), device);
|
|
||||||
|
|
||||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
let mask = match args.model {
|
||||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
|
||||||
|
Which::DistilBert => attention_mask(tokens.len(), device)?,
|
||||||
|
};
|
||||||
|
|
||||||
let ys = model.forward(&token_ids, &mask)?;
|
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
|
||||||
println!("{ys}");
|
|
||||||
|
Ok((token_ids, mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_output(
|
||||||
|
model: &ModelType,
|
||||||
|
output: &Tensor,
|
||||||
|
token_ids: &Tensor,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
args: &Args,
|
||||||
|
) -> Result<()> {
|
||||||
|
match model {
|
||||||
|
ModelType::UnMasked(_) => {
|
||||||
|
println!("embeddings");
|
||||||
|
println!("{output}");
|
||||||
|
}
|
||||||
|
ModelType::Masked(_) => {
|
||||||
|
process_masked_output(output, token_ids, tokenizer, args)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
fn process_masked_output(
|
||||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
output: &Tensor,
|
||||||
|
token_ids: &Tensor,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
args: &Args,
|
||||||
|
) -> Result<()> {
|
||||||
|
let input_ids_vec = token_ids.to_vec2::<u32>()?;
|
||||||
|
let mask_token_id = tokenizer
|
||||||
|
.token_to_id("[MASK]")
|
||||||
|
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||||
|
|
||||||
|
println!("\nInput: {}", args.prompt);
|
||||||
|
|
||||||
|
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
|
||||||
|
if token_id == mask_token_id {
|
||||||
|
println!("Predictions for [MASK] at position {}:", token_idx);
|
||||||
|
|
||||||
|
let pos_logits = output.get(0)?.get(token_idx)?;
|
||||||
|
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
|
||||||
|
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
|
||||||
|
|
||||||
|
let values = top_values.to_vec1::<f32>()?;
|
||||||
|
let indices = top_indices.to_vec1::<u32>()?;
|
||||||
|
|
||||||
|
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
|
||||||
|
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
|
||||||
|
println!(
|
||||||
|
" {}: {:15} (probability: {:.2}%)",
|
||||||
|
i + 1,
|
||||||
|
token,
|
||||||
|
prob * 100.0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
|
||||||
|
let n = tensor.dims().iter().product::<usize>();
|
||||||
|
let k = std::cmp::min(k, n);
|
||||||
|
|
||||||
|
let values = tensor.to_vec1::<f32>()?;
|
||||||
|
let mut value_indices: Vec<(f32, usize)> = values
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, val)| (val, idx))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
|
||||||
|
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
|
||||||
|
let top_k_indices: Vec<u32> = value_indices
|
||||||
|
.iter()
|
||||||
|
.take(k)
|
||||||
|
.map(|(_, idx)| *idx as u32)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let device = tensor.device();
|
||||||
|
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
|
||||||
|
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
|
||||||
|
|
||||||
|
Ok((top_values, top_indices))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..size)
|
||||||
|
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||||
|
.collect();
|
||||||
|
Ok(Tensor::from_slice(&mask, (size, size), device)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
|
||||||
|
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
|
||||||
|
let seq_len = tokens.get_attention_mask().to_vec().len();
|
||||||
|
|
||||||
|
let mask_token_id = tokenizer
|
||||||
|
.token_to_id("[MASK]")
|
||||||
|
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||||
|
|
||||||
|
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
|
||||||
|
|
||||||
|
let ids = tokens.get_ids();
|
||||||
|
for _ in 0..seq_len {
|
||||||
|
for id in ids.iter() {
|
||||||
|
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
|
||||||
|
attention_mask_vec.push(mask_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let shape = (1, 1, seq_len, seq_len);
|
||||||
|
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
|
||||||
|
|
||||||
|
Ok(mask)
|
||||||
}
|
}
|
||||||
|
15
candle-examples/examples/efficientnet/README.md
Normal file
15
candle-examples/examples/efficientnet/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# candle-efficientnet
|
||||||
|
|
||||||
|
Demonstrates a Candle implementation of EfficientNet for image classification based on ImageNet classes.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example efficientnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which b1
|
||||||
|
|
||||||
|
> bicycle-built-for-two, tandem bicycle, tandem: 45.85%
|
||||||
|
> mountain bike, all-terrain bike, off-roader: 30.45%
|
||||||
|
> crash helmet : 2.58%
|
||||||
|
> unicycle, monocycle : 2.21%
|
||||||
|
> tricycle, trike, velocipede: 1.53%
|
||||||
|
```
|
@ -1,3 +1,10 @@
|
|||||||
# candle-falcon
|
# candle-falcon
|
||||||
|
|
||||||
Falcon is a general large language model.
|
Falcon is a general large language model.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
Make sure to include the `--use-f32` flag if using CPU, because there isn't a BFloat16 implementation yet.
|
||||||
|
```
|
||||||
|
cargo run --example falcon --release -- --prompt "Flying monkeys are" --use-f32
|
||||||
|
```
|
@ -12,7 +12,7 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
|
|||||||
|
|
||||||
** Running with ~cpu~
|
** Running with ~cpu~
|
||||||
#+begin_src shell
|
#+begin_src shell
|
||||||
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
|
cargo run --example glm4 --release -- --cpu --prompt "Hello world"
|
||||||
#+end_src
|
#+end_src
|
||||||
|
|
||||||
** Output Example
|
** Output Example
|
||||||
|
@ -46,7 +46,7 @@ impl TextGeneration {
|
|||||||
Sampling::ArgMax
|
Sampling::ArgMax
|
||||||
} else {
|
} else {
|
||||||
match (top_k, top_p) {
|
match (top_k, top_p) {
|
||||||
(None, None) => Sampling::All { temperature },
|
(None, None) => Sampling::GumbelSoftmax { temperature },
|
||||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
11
candle-examples/examples/llama/README.md
Normal file
11
candle-examples/examples/llama/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# candle-llama
|
||||||
|
|
||||||
|
Candle implementations of various Llama-based architectures.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example llama -- --prompt "Machine learning is " --which v32-3b-instruct
|
||||||
|
|
||||||
|
> Machine learning is the part of computer science which deals with the development of algorithms and
|
||||||
|
```
|
@ -21,7 +21,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_conv(&self) -> usize {
|
fn d_conv(&self) -> usize {
|
||||||
|
@ -12,6 +12,6 @@ would only work for inference.
|
|||||||
## Running the example
|
## Running the example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
|
$ cargo run --example mamba --release -- --prompt "Mamba is the"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -18,21 +18,19 @@ I know you are waiting for me. I will go through the forest, I will go through t
|
|||||||
mountain. I cannot stay far from you any longer.</s>
|
mountain. I cannot stay far from you any longer.</s>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Changing model and language pairs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example marian-mt --release -- --text "hello, how are you." --which base --language-pair en-zh
|
||||||
|
|
||||||
|
你好,你好吗?
|
||||||
|
```
|
||||||
|
|
||||||
## Generating the tokenizer.json files
|
## Generating the tokenizer.json files
|
||||||
|
|
||||||
You can use the following script to generate the `tokenizer.json` config files
|
The tokenizer for each `marian-mt` model was trained independently,
|
||||||
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
|
meaning each new model needs unique tokenizer encoders and decoders.
|
||||||
packages to be install and use the `convert_slow_tokenizer.py` script from this
|
You can use the `./python/convert_slow_tokenizer.py` script in this directory to generate
|
||||||
directory.
|
the `tokenizer.json` config files from the hf-hub repos.
|
||||||
|
The script requires all the packages in `./python/requirements.txt` or `./python/uv.lock`
|
||||||
```python
|
to be installed, and has only been tested for `python 3.12.7`.
|
||||||
from convert_slow_tokenizer import MarianConverter
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
|
|
||||||
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
|
|
||||||
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
|
|
||||||
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
|
|
||||||
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
|
|
||||||
```
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -20,6 +20,22 @@ enum Which {
|
|||||||
Big,
|
Big,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Translation directions that can be selected on the command line. The string in each
// `#[value(name = ...)]` attribute is the spelling clap accepts for that variant.
// Plain `//` comments are used here on purpose: `///` comments on a `ValueEnum` are
// picked up by clap as help text and would change the program's `--help` output.
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum LanguagePair {
    // French -> English.
    #[value(name = "fr-en")]
    FrEn,
    // English -> Chinese.
    #[value(name = "en-zh")]
    EnZh,
    // English -> Hindi.
    #[value(name = "en-hi")]
    EnHi,
    // English -> Spanish.
    #[value(name = "en-es")]
    EnEs,
    // English -> French.
    #[value(name = "en-fr")]
    EnFr,
    // English -> Russian.
    #[value(name = "en-ru")]
    EnRu,
}
|
||||||
|
|
||||||
// TODO: Maybe add support for the conditional prompt.
|
// TODO: Maybe add support for the conditional prompt.
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -36,6 +52,10 @@ struct Args {
|
|||||||
#[arg(long, default_value = "big")]
|
#[arg(long, default_value = "big")]
|
||||||
which: Which,
|
which: Which,
|
||||||
|
|
||||||
|
// Choose which language pair to use
|
||||||
|
#[arg(long, default_value = "fr-en")]
|
||||||
|
language_pair: LanguagePair,
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
/// Run on CPU rather than on GPU.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
@ -53,21 +73,43 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
use hf_hub::api::sync::Api;
|
use hf_hub::api::sync::Api;
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let config = match args.which {
|
let config = match (args.which, args.language_pair) {
|
||||||
Which::Base => marian::Config::opus_mt_fr_en(),
|
(Which::Base, LanguagePair::FrEn) => marian::Config::opus_mt_fr_en(),
|
||||||
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
|
(Which::Big, LanguagePair::FrEn) => marian::Config::opus_mt_tc_big_fr_en(),
|
||||||
|
(Which::Base, LanguagePair::EnZh) => marian::Config::opus_mt_en_zh(),
|
||||||
|
(Which::Base, LanguagePair::EnHi) => marian::Config::opus_mt_en_hi(),
|
||||||
|
(Which::Base, LanguagePair::EnEs) => marian::Config::opus_mt_en_es(),
|
||||||
|
(Which::Base, LanguagePair::EnFr) => marian::Config::opus_mt_fr_en(),
|
||||||
|
(Which::Base, LanguagePair::EnRu) => marian::Config::opus_mt_en_ru(),
|
||||||
|
(Which::Big, lp) => anyhow::bail!("big is not supported for language pair {lp:?}"),
|
||||||
|
};
|
||||||
|
let tokenizer_default_repo = match args.language_pair {
|
||||||
|
LanguagePair::FrEn => "lmz/candle-marian",
|
||||||
|
LanguagePair::EnZh
|
||||||
|
| LanguagePair::EnHi
|
||||||
|
| LanguagePair::EnEs
|
||||||
|
| LanguagePair::EnFr
|
||||||
|
| LanguagePair::EnRu => "KeighBee/candle-marian",
|
||||||
};
|
};
|
||||||
let tokenizer = {
|
let tokenizer = {
|
||||||
let tokenizer = match args.tokenizer {
|
let tokenizer = match args.tokenizer {
|
||||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
None => {
|
None => {
|
||||||
let name = match args.which {
|
let filename = match (args.which, args.language_pair) {
|
||||||
Which::Base => "tokenizer-marian-base-fr.json",
|
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-fr.json",
|
||||||
Which::Big => "tokenizer-marian-fr.json",
|
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-fr.json",
|
||||||
|
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-en.json",
|
||||||
|
(Which::Big, lp) => {
|
||||||
|
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Api::new()?
|
Api::new()?
|
||||||
.model("lmz/candle-marian".to_string())
|
.model(tokenizer_default_repo.to_string())
|
||||||
.get(name)?
|
.get(filename)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
@ -77,13 +119,21 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let tokenizer = match args.tokenizer_dec {
|
let tokenizer = match args.tokenizer_dec {
|
||||||
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
|
||||||
None => {
|
None => {
|
||||||
let name = match args.which {
|
let filename = match (args.which, args.language_pair) {
|
||||||
Which::Base => "tokenizer-marian-base-en.json",
|
(Which::Base, LanguagePair::FrEn) => "tokenizer-marian-base-en.json",
|
||||||
Which::Big => "tokenizer-marian-en.json",
|
(Which::Big, LanguagePair::FrEn) => "tokenizer-marian-en.json",
|
||||||
|
(Which::Base, LanguagePair::EnZh) => "tokenizer-marian-base-en-zh-zh.json",
|
||||||
|
(Which::Base, LanguagePair::EnHi) => "tokenizer-marian-base-en-hi-hi.json",
|
||||||
|
(Which::Base, LanguagePair::EnEs) => "tokenizer-marian-base-en-es-es.json",
|
||||||
|
(Which::Base, LanguagePair::EnFr) => "tokenizer-marian-base-en-fr-fr.json",
|
||||||
|
(Which::Base, LanguagePair::EnRu) => "tokenizer-marian-base-en-ru-ru.json",
|
||||||
|
(Which::Big, lp) => {
|
||||||
|
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Api::new()?
|
Api::new()?
|
||||||
.model("lmz/candle-marian".to_string())
|
.model(tokenizer_default_repo.to_string())
|
||||||
.get(name)?
|
.get(filename)?
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
|
||||||
@ -94,18 +144,48 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let vb = {
|
let vb = {
|
||||||
let model = match args.model {
|
let model = match args.model {
|
||||||
Some(model) => std::path::PathBuf::from(model),
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
None => match args.which {
|
None => {
|
||||||
Which::Base => Api::new()?
|
let api = Api::new()?;
|
||||||
.repo(hf_hub::Repo::with_revision(
|
let api = match (args.which, args.language_pair) {
|
||||||
|
(Which::Base, LanguagePair::FrEn) => api.repo(hf_hub::Repo::with_revision(
|
||||||
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
"Helsinki-NLP/opus-mt-fr-en".to_string(),
|
||||||
hf_hub::RepoType::Model,
|
hf_hub::RepoType::Model,
|
||||||
"refs/pr/4".to_string(),
|
"refs/pr/4".to_string(),
|
||||||
))
|
)),
|
||||||
.get("model.safetensors")?,
|
(Which::Big, LanguagePair::FrEn) => {
|
||||||
Which::Big => Api::new()?
|
api.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
||||||
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
|
}
|
||||||
.get("model.safetensors")?,
|
(Which::Base, LanguagePair::EnZh) => api.repo(hf_hub::Repo::with_revision(
|
||||||
},
|
"Helsinki-NLP/opus-mt-en-zh".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/13".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnHi) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-hi".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/3".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnEs) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-es".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/4".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnFr) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-fr".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/9".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Base, LanguagePair::EnRu) => api.repo(hf_hub::Repo::with_revision(
|
||||||
|
"Helsinki-NLP/opus-mt-en-ru".to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"refs/pr/7".to_string(),
|
||||||
|
)),
|
||||||
|
(Which::Big, lp) => {
|
||||||
|
anyhow::bail!("big is not supported for language pair {lp:?}")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from transformers.convert_slow_tokenizer import SpmConverter, requires_backends, import_protobuf
|
||||||
|
|
||||||
|
class MarianConverter(SpmConverter):
    """Convert one side of a slow Marian tokenizer into a fast `tokenizers` tokenizer.

    A slow Marian tokenizer exposes several sentencepiece model files through
    ``spm_files``; ``index`` selects which of them is loaded into ``self.proto``.
    Token ids are taken from the ``vocab.json`` file that sits next to the first
    sentencepiece file, and are used by :meth:`vocab` to order the pieces.

    Args:
        *args: Forwarded to the base converter initializer (the slow tokenizer).
        index: Position in ``original_tokenizer.spm_files`` of the sentencepiece
            model to convert. Defaults to 0.
    """

    def __init__(self, *args, index: int = 0):
        requires_backends(self, "protobuf")

        # `super(SpmConverter, self)` starts the lookup *after* SpmConverter in the MRO,
        # so SpmConverter's own initializer is bypassed; the sentencepiece model is
        # loaded manually below instead.
        super(SpmConverter, self).__init__(*args)

        # from .utils import sentencepiece_model_pb2 as model_pb2
        model_pb2 = import_protobuf()

        m = model_pb2.ModelProto()
        print(self.original_tokenizer.spm_files)
        with open(self.original_tokenizer.spm_files[index], "rb") as f:
            m.ParseFromString(f.read())
        self.proto = m
        print(self.original_tokenizer)
        #with open(self.original_tokenizer.vocab_path, "r") as f:
        dir_path = Path(self.original_tokenizer.spm_files[0]).parents[0]
        # Read the vocabulary as UTF-8 explicitly: it contains non-ASCII tokens and the
        # platform default encoding is not UTF-8 everywhere.
        with open(dir_path / "vocab.json", "r", encoding="utf-8") as f:
            import json
            self._vocab = json.load(f)

        if self.proto.trainer_spec.byte_fallback:
            if not getattr(self, "handle_byte_fallback", None):
                warnings.warn(
                    "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
                    " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
                    " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
                    "unknown tokens into a sequence of byte tokens matching the original piece of text."
                )

    def vocab(self, proto):
        """Return the ``(piece, score)`` list indexed by the ids from ``vocab.json``.

        The list has ``max(id) + 1`` entries. Slots whose id has no sentencepiece
        piece keep the placeholder ``("<NIL>", -100)``. Pieces that are absent from
        ``vocab.json`` are reported and skipped.
        """
        vocab_size = max(self._vocab.values()) + 1
        vocab = [("<NIL>", -100) for _ in range(vocab_size)]
        for piece in proto.pieces:
            try:
                index = self._vocab[piece.piece]
            except KeyError:
                print(f"Ignored missing piece {piece.piece}")
                # Skip the piece: without this, the assignment below would reuse the
                # index of the previous piece (or fail if there was none yet).
                continue
            vocab[index] = (piece.piece, piece.score)
        return vocab
|
||||||
|
|
||||||
|
|
||||||
|
# Example conversion for the fr-en base model. Running this module loads the slow
# tokenizer and writes two fast tokenizer files into the current working directory.
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
# index=0 selects the first sentencepiece model, saved here as the "fr" tokenizer.
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save("tokenizer-marian-base-fr.json")
# index=1 selects the second sentencepiece model, saved here as the "en" tokenizer.
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save("tokenizer-marian-base-en.json")
|
22
candle-examples/examples/marian-mt/python/requirements.txt
Normal file
22
candle-examples/examples/marian-mt/python/requirements.txt
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
certifi==2025.1.31
|
||||||
|
charset-normalizer==3.4.1
|
||||||
|
click==8.1.8
|
||||||
|
filelock==3.18.0
|
||||||
|
fsspec==2025.3.2
|
||||||
|
huggingface-hub==0.30.1
|
||||||
|
idna==3.10
|
||||||
|
joblib==1.4.2
|
||||||
|
numpy==2.2.4
|
||||||
|
packaging==24.2
|
||||||
|
protobuf==6.30.2
|
||||||
|
pyyaml==6.0.2
|
||||||
|
regex==2024.11.6
|
||||||
|
requests==2.32.3
|
||||||
|
sacremoses==0.1.1
|
||||||
|
safetensors==0.5.3
|
||||||
|
sentencepiece==0.2.0
|
||||||
|
tokenizers==0.21.1
|
||||||
|
tqdm==4.67.1
|
||||||
|
transformers==4.50.3
|
||||||
|
typing-extensions==4.13.0
|
||||||
|
urllib3==2.3.0
|
@ -13,6 +13,6 @@ Note that the current candle implementation suffers from some limitations as of
|
|||||||
## Run an example
|
## Run an example
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example metavoice --release -- \\
|
cargo run --example metavoice --release -- \
|
||||||
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
|
--prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
|
||||||
```
|
```
|
||||||
|
16
candle-examples/examples/mnist-training/README.md
Normal file
16
candle-examples/examples/mnist-training/README.md
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
# candle-mnist-training
|
||||||
|
|
||||||
|
Training a 2-layer MLP on MNIST in Candle.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example mnist-training --features candle-datasets
|
||||||
|
|
||||||
|
> train-images: [60000, 784]
|
||||||
|
> train-labels: [60000]
|
||||||
|
> test-images: [10000, 784]
|
||||||
|
> test-labels: [10000]
|
||||||
|
> 1 train loss: 2.30265 test acc: 68.08%
|
||||||
|
> 2 train loss: 1.50815 test acc: 60.77%
|
||||||
|
```
|
@ -12,7 +12,7 @@ $ wget https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jp
|
|||||||
|
|
||||||
Now you can run Moondream from the `candle-examples` crate:
|
Now you can run Moondream from the `candle-examples` crate:
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example moondream --release -- --prompt "What is the girl eating?" --image "./demo-1.jpg"
|
$ cargo run --example moondream --release -- --prompt "Describe the people behind the bikers?" --image "candle-examples/examples/yolo-v8/assets/bike.jpg"
|
||||||
|
|
||||||
avavx: false, neon: true, simd128: false, f16c: false
|
avx: false, neon: true, simd128: false, f16c: false
|
||||||
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
|
temp: 0.00 repeat-penalty: 1.00 repeat-last-n: 64
|
||||||
|
20
candle-examples/examples/musicgen/README.md
Normal file
20
candle-examples/examples/musicgen/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-musicgen
|
||||||
|
|
||||||
|
Candle implementation of musicgen from [Simple and Controllable Music Generation](https://arxiv.org/pdf/2306.05284).
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example musicgen -- --prompt "90s rock song with loud guitars and heavy drums"
|
||||||
|
|
||||||
|
> tokens: [2777, 7, 2480, 2324, 28, 8002, 5507, 7, 11, 2437, 5253, 7, 1]
|
||||||
|
> Tensor[dims 1, 13; u32]
|
||||||
|
> [[[ 0.0902, 0.1256, -0.0585, ..., 0.1057, -0.5141, -0.4675],
|
||||||
|
> [ 0.1972, -0.0268, -0.3368, ..., -0.0495, -0.3597, -0.3940],
|
||||||
|
> [-0.0855, -0.0007, 0.2225, ..., -0.2804, -0.5360, -0.2436],
|
||||||
|
> ...
|
||||||
|
> [ 0.0515, 0.0235, -0.3855, ..., -0.4728, -0.6858, -0.2923],
|
||||||
|
> [-0.3728, -0.1442, -0.1179, ..., -0.4388, -0.0287, -0.3242],
|
||||||
|
> [ 0.0163, 0.0012, -0.0020, ..., 0.0142, 0.0173, -0.0103]]]
|
||||||
|
> Tensor[[1, 13, 768], f32]
|
||||||
|
```
|
14
candle-examples/examples/orpheus/README.md
Normal file
14
candle-examples/examples/orpheus/README.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Orpheus
|
||||||
|
|
||||||
|
Orpheus is a 3B text-to-speech model based on Llama.
|
||||||
|
|
||||||
|
- Weights on HuggingFace
|
||||||
|
[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
|
||||||
|
- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example orpheus --features cuda -r
|
||||||
|
```
|
||||||
|
|
||||||
|
|
329
candle-examples/examples/orpheus/main.rs
Normal file
329
candle-examples/examples/orpheus/main.rs
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle::{DType, Device, IndexOp, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
|
||||||
|
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
|
||||||
|
const STOP_TOKEN_ID: u32 = 128258;
|
||||||
|
|
||||||
|
// Command-line options of the Orpheus text-to-speech example, parsed by clap.
// Fields that had no documentation are annotated with plain `//` comments on purpose:
// clap turns `///` doc comments into `--help` text, so adding them would change output.
#[derive(Parser)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    // Text to synthesize; the selected voice name is prepended to it before tokenization.
    #[arg(long, default_value = "Hey, how are you doing today?")]
    prompt: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 0.6)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    // Hub repository to load from; when absent the repository is chosen from `which`.
    #[arg(long)]
    model_id: Option<String>,

    // Hub revision to use; "main" is used when this is not given.
    #[arg(long)]
    revision: Option<String>,

    // Local weights file, used instead of the files listed in the hub index.
    #[arg(long)]
    model_file: Option<String>,

    // Local tokenizer file, used instead of the hub's `tokenizer.json`.
    #[arg(long)]
    tokenizer_file: Option<String>,

    // Local model config file, used instead of the hub's `config.json`.
    #[arg(long)]
    config_file: Option<String>,

    /// The output wav file.
    #[arg(long, default_value = "out.wav")]
    out_file: String,

    // Model variant; it selects the default repository and weight files.
    #[arg(long, default_value = "3b-0.1-ft")]
    which: Which,

    // Speaker whose name is prefixed to the prompt.
    #[arg(long, default_value = "tara")]
    voice: Voice,

    // Passed through to `LlamaConfig::into_config` when the model config is built.
    #[arg(long)]
    use_flash_attn: bool,
}
|
||||||
|
|
||||||
|
// Speaker names accepted by `--voice`. The `#[value(name = ...)]` string is the
// command-line spelling of each variant. `//` comments are used instead of `///`
// so that clap's generated help text stays exactly as it was.
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Voice {
    #[value(name = "tara")]
    Tara,
    #[value(name = "leah")]
    Leah,
    #[value(name = "jess")]
    Jess,
    #[value(name = "leo")]
    Leo,
    #[value(name = "dan")]
    Dan,
    #[value(name = "mia")]
    Mia,
    #[value(name = "zac")]
    Zac,
    #[value(name = "zoe")]
    Zoe,
}
|
||||||
|
|
||||||
|
impl Voice {
|
||||||
|
fn as_str(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Voice::Tara => "tara",
|
||||||
|
Voice::Leah => "leah",
|
||||||
|
Voice::Jess => "jess",
|
||||||
|
Voice::Leo => "leo",
|
||||||
|
Voice::Dan => "dan",
|
||||||
|
Voice::Mia => "mia",
|
||||||
|
Voice::Zac => "zac",
|
||||||
|
Voice::Zoe => "zoe",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model variants selectable with `--which`; there is currently a single one.
// (`//` rather than `///` so that clap's help output is not altered.)
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
    // Command-line spelling: "3b-0.1-ft".
    #[value(name = "3b-0.1-ft")]
    ThreeB0_1Ft,
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
let prompt = args.prompt.clone();
|
||||||
|
let mut model = Model::load(args)?;
|
||||||
|
model.run(&prompt)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Everything needed to turn a text prompt into audio tokens: the language model, its
/// tokenizer and sampler, plus the settings taken from the command line.
struct Model {
    /// Llama model that produces the logits for the next token.
    model: Llama,
    /// Tokenizer used to encode the prompt and to map sampled ids back to token strings.
    tokenizer: Tokenizer,
    /// Sampler that picks the next token id from the logits.
    logits_processor: candle_transformers::generation::LogitsProcessor,
    /// Llama cache created at load time; `run` works on a clone of it.
    cache: Cache,
    /// Device on which the input tensors are created.
    device: Device,
    /// When set, the prompt token ids are printed before generation starts.
    verbose_prompt: bool,
    /// SNAC model returned by `load_snac`.
    snac: SnacModel,
    /// Output file path taken from `--out-file`.
    out_file: String,
    /// Voice whose name is prefixed to the prompt.
    voice: Voice,
}
|
||||||
|
|
||||||
|
fn load_snac(device: &Device) -> Result<SnacModel> {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
|
||||||
|
let config = m.get("config.json")?;
|
||||||
|
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
|
||||||
|
let m = api.model("lmz/candle-snac".to_string());
|
||||||
|
let model = m.get("snac_24khz.safetensors")?;
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
|
||||||
|
let model = SnacModel::new(&config, vb)?;
|
||||||
|
Ok(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
fn load(args: Args) -> Result<Self> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => match args.which {
|
||||||
|
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let revision = match args.revision {
|
||||||
|
Some(r) => r,
|
||||||
|
None => "main".to_string(),
|
||||||
|
};
|
||||||
|
let repo = api.repo(hf_hub::Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
revision,
|
||||||
|
));
|
||||||
|
let model_files = match args.model_file {
|
||||||
|
Some(m) => vec![m.into()],
|
||||||
|
None => match args.which {
|
||||||
|
Which::ThreeB0_1Ft => {
|
||||||
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let config = match args.config_file {
|
||||||
|
Some(m) => m.into(),
|
||||||
|
None => repo.get("config.json")?,
|
||||||
|
};
|
||||||
|
let tokenizer = match args.tokenizer_file {
|
||||||
|
Some(m) => m.into(),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = device.bf16_default_to_f32();
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
|
||||||
|
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
|
||||||
|
let config = config.into_config(args.use_flash_attn);
|
||||||
|
let model = Llama::load(vb, &config)?;
|
||||||
|
let logits_processor = {
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
let temperature = args.temperature;
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (args.top_k.as_ref(), args.top_p.as_ref()) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(&k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(&p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
let cache = Cache::new(true, dtype, &config, &device)?;
|
||||||
|
let snac = load_snac(&device)?;
|
||||||
|
Ok(Self {
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
logits_processor,
|
||||||
|
cache,
|
||||||
|
device,
|
||||||
|
verbose_prompt: args.verbose_prompt,
|
||||||
|
snac,
|
||||||
|
voice: args.voice,
|
||||||
|
out_file: args.out_file,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str) -> Result<()> {
|
||||||
|
println!("running the model on '{}'", prompt);
|
||||||
|
let device = &self.device;
|
||||||
|
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
|
||||||
|
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
|
||||||
|
let mut tokens = [
|
||||||
|
&[128259],
|
||||||
|
tokens.get_ids(),
|
||||||
|
&[128009, 128260, 128261, 128257],
|
||||||
|
]
|
||||||
|
.concat();
|
||||||
|
if self.verbose_prompt {
|
||||||
|
println!("{:?}", tokens);
|
||||||
|
}
|
||||||
|
let mut cache = self.cache.clone();
|
||||||
|
|
||||||
|
println!("starting the inference loop");
|
||||||
|
let mut index_pos = 0;
|
||||||
|
let mut audio_tokens = vec![];
|
||||||
|
for index in 0..2000 {
|
||||||
|
let (context_size, context_index) = if index > 0 {
|
||||||
|
(1, index_pos)
|
||||||
|
} else {
|
||||||
|
(tokens.len(), 0)
|
||||||
|
};
|
||||||
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
|
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, context_index, &mut cache)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
|
||||||
|
match tok.strip_prefix("<custom_token_") {
|
||||||
|
Some(tok) => match tok.strip_suffix('>') {
|
||||||
|
Some(tok) => {
|
||||||
|
let tok = tok.parse::<u32>()?;
|
||||||
|
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
|
||||||
|
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
|
||||||
|
audio_tokens.push(tok);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
println!("{index}: unexpected custom token {next_token} {tok}");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => {
|
||||||
|
println!("{index}: unexpected token {next_token} {tok}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if next_token == STOP_TOKEN_ID {
|
||||||
|
println!("reached stop token");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tokens.push(next_token);
|
||||||
|
}
|
||||||
|
println!("generated {} audio tokens", audio_tokens.len());
|
||||||
|
let mut codes0 = vec![];
|
||||||
|
let mut codes1 = vec![];
|
||||||
|
let mut codes2 = vec![];
|
||||||
|
for audio_tokens in audio_tokens.chunks_exact(7) {
|
||||||
|
codes0.push(audio_tokens[0]);
|
||||||
|
for i in [1, 4] {
|
||||||
|
codes1.push(audio_tokens[i]);
|
||||||
|
}
|
||||||
|
for i in [2, 3, 5, 6] {
|
||||||
|
codes2.push(audio_tokens[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
|
||||||
|
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
|
||||||
|
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
|
||||||
|
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
|
||||||
|
println!("decoded to pcm {pcm:?}");
|
||||||
|
let mut output = std::fs::File::create(&self.out_file)?;
|
||||||
|
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
20
candle-examples/examples/quantized-phi/README.md
Normal file
20
candle-examples/examples/quantized-phi/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
# candle-quantized-phi
|
||||||
|
|
||||||
|
Candle implementation of various quantized Phi models.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example quantized-phi --release -- --prompt "The best thing about coding in rust is "
|
||||||
|
|
||||||
|
> - it's memory safe (without you having to worry too much)
|
||||||
|
> - the borrow checker is really smart and will catch your mistakes for free, making them show up as compile errors instead of segfaulting in runtime.
|
||||||
|
>
|
||||||
|
> This alone make me prefer using rust over c++ or go, python/Cython etc.
|
||||||
|
>
|
||||||
|
> The major downside I can see now:
|
||||||
|
> - it's slower than other languages (viz: C++) and most importantly lack of libraries to leverage existing work done by community in that language. There are so many useful machine learning libraries available for c++, go, python etc but none for Rust as far as I am aware of on the first glance.
|
||||||
|
> - there aren't a lot of production ready projects which also makes it very hard to start new one (given my background)
|
||||||
|
>
|
||||||
|
> Another downside:
|
||||||
|
```
|
@ -27,6 +27,8 @@ enum Which {
|
|||||||
W2_7b,
|
W2_7b,
|
||||||
#[value(name = "72b")]
|
#[value(name = "72b")]
|
||||||
W2_72b,
|
W2_72b,
|
||||||
|
#[value(name = "deepseekr1-qwen7b")]
|
||||||
|
DeepseekR1Qwen7B,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -102,6 +104,7 @@ impl Args {
|
|||||||
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
|
Which::W2_1_5b => "Qwen/Qwen2-1.5B-Instruct",
|
||||||
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
|
Which::W2_7b => "Qwen/Qwen2-7B-Instruct",
|
||||||
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
|
Which::W2_72b => "Qwen/Qwen2-72B-Instruct",
|
||||||
|
Which::DeepseekR1Qwen7B => "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
||||||
};
|
};
|
||||||
let api = api.model(repo.to_string());
|
let api = api.model(repo.to_string());
|
||||||
api.get("tokenizer.json")?
|
api.get("tokenizer.json")?
|
||||||
@ -135,6 +138,11 @@ impl Args {
|
|||||||
"qwen2-72b-instruct-q4_0.gguf",
|
"qwen2-72b-instruct-q4_0.gguf",
|
||||||
"main",
|
"main",
|
||||||
),
|
),
|
||||||
|
Which::DeepseekR1Qwen7B => (
|
||||||
|
"unsloth/DeepSeek-R1-Distill-Qwen-7B-GGUF",
|
||||||
|
"DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf",
|
||||||
|
"main",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
api.repo(hf_hub::Repo::with_revision(
|
api.repo(hf_hub::Repo::with_revision(
|
||||||
@ -211,11 +219,15 @@ fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let tokenizer = args.tokenizer()?;
|
let tokenizer = args.tokenizer()?;
|
||||||
let mut tos = TokenOutputStream::new(tokenizer);
|
let mut tos = TokenOutputStream::new(tokenizer);
|
||||||
let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
let prompt_str = args
|
||||||
let prompt_str = format!(
|
.prompt
|
||||||
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
|
.clone()
|
||||||
prompt_str
|
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||||
);
|
|
||||||
|
let prompt_str = match args.which {
|
||||||
|
Which::DeepseekR1Qwen7B => format!("<|User|>{prompt_str}<|Assistant|>"),
|
||||||
|
_ => format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"),
|
||||||
|
};
|
||||||
print!("formatted instruct prompt: {}", &prompt_str);
|
print!("formatted instruct prompt: {}", &prompt_str);
|
||||||
let tokens = tos
|
let tokens = tos
|
||||||
.tokenizer()
|
.tokenizer()
|
||||||
@ -260,7 +272,13 @@ fn main() -> anyhow::Result<()> {
|
|||||||
print!("{t}");
|
print!("{t}");
|
||||||
std::io::stdout().flush()?;
|
std::io::stdout().flush()?;
|
||||||
}
|
}
|
||||||
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
|
|
||||||
|
let eos_token = match args.which {
|
||||||
|
Which::DeepseekR1Qwen7B => "<|end▁of▁sentence|>",
|
||||||
|
_ => "<|im_end|>",
|
||||||
|
};
|
||||||
|
|
||||||
|
let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap();
|
||||||
let start_post_prompt = std::time::Instant::now();
|
let start_post_prompt = std::time::Instant::now();
|
||||||
let mut sampled = 0;
|
let mut sampled = 0;
|
||||||
for index in 0..to_sample {
|
for index in 0..to_sample {
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# candle-quantized-t5
|
# candle-quantized-t5
|
||||||
|
|
||||||
|
Candle implementation for quantizing and running T5 translation models.
|
||||||
|
|
||||||
## Seq2Seq example
|
## Seq2Seq example
|
||||||
|
|
||||||
This example uses a quantized version of the t5 model.
|
This example uses a quantized version of the t5 model.
|
||||||
|
@ -75,6 +75,8 @@ enum Which {
|
|||||||
SmolLM2_360MInstruct,
|
SmolLM2_360MInstruct,
|
||||||
#[value(name = "SmoLM2-1.7B-Instruct")]
|
#[value(name = "SmoLM2-1.7B-Instruct")]
|
||||||
SmolLM2_1BInstruct,
|
SmolLM2_1BInstruct,
|
||||||
|
#[value(name = "deepseekr1-llama8b")]
|
||||||
|
DeepseekR1Llama8b,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Which {
|
impl Which {
|
||||||
@ -94,7 +96,8 @@ impl Which {
|
|||||||
| Self::L8b
|
| Self::L8b
|
||||||
| Self::Phi3
|
| Self::Phi3
|
||||||
| Self::SmolLM2_1BInstruct
|
| Self::SmolLM2_1BInstruct
|
||||||
| Self::SmolLM2_360MInstruct => false,
|
| Self::SmolLM2_360MInstruct
|
||||||
|
| Self::DeepseekR1Llama8b => false,
|
||||||
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
|
||||||
// same way. Starling is a fine tuned version of OpenChat.
|
// same way. Starling is a fine tuned version of OpenChat.
|
||||||
Self::OpenChat35
|
Self::OpenChat35
|
||||||
@ -132,7 +135,8 @@ impl Which {
|
|||||||
| Self::L8b
|
| Self::L8b
|
||||||
| Self::SmolLM2_1BInstruct
|
| Self::SmolLM2_1BInstruct
|
||||||
| Self::SmolLM2_360MInstruct
|
| Self::SmolLM2_360MInstruct
|
||||||
| Self::Phi3 => false,
|
| Self::Phi3
|
||||||
|
| Self::DeepseekR1Llama8b => false,
|
||||||
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -160,11 +164,41 @@ impl Which {
|
|||||||
| Self::L8b
|
| Self::L8b
|
||||||
| Self::SmolLM2_1BInstruct
|
| Self::SmolLM2_1BInstruct
|
||||||
| Self::SmolLM2_360MInstruct
|
| Self::SmolLM2_360MInstruct
|
||||||
| Self::Phi3 => false,
|
| Self::Phi3
|
||||||
|
| Self::DeepseekR1Llama8b => false,
|
||||||
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
Self::OpenChat35 | Self::Starling7bAlpha => true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_deepseek(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
Self::L7b
|
||||||
|
| Self::L13b
|
||||||
|
| Self::L70b
|
||||||
|
| Self::L7bChat
|
||||||
|
| Self::L13bChat
|
||||||
|
| Self::L70bChat
|
||||||
|
| Self::L7bCode
|
||||||
|
| Self::L13bCode
|
||||||
|
| Self::L34bCode
|
||||||
|
| Self::Leo7b
|
||||||
|
| Self::Leo13b
|
||||||
|
| Self::Mixtral
|
||||||
|
| Self::MixtralInstruct
|
||||||
|
| Self::Mistral7b
|
||||||
|
| Self::Mistral7bInstruct
|
||||||
|
| Self::Mistral7bInstructV02
|
||||||
|
| Self::Zephyr7bAlpha
|
||||||
|
| Self::Zephyr7bBeta
|
||||||
|
| Self::L8b
|
||||||
|
| Self::SmolLM2_1BInstruct
|
||||||
|
| Self::SmolLM2_360MInstruct
|
||||||
|
| Self::Phi3
|
||||||
|
| Self::OpenChat35
|
||||||
|
| Self::Starling7bAlpha => false,
|
||||||
|
Self::DeepseekR1Llama8b => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
fn tokenizer_repo(&self) -> &'static str {
|
fn tokenizer_repo(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::L7b
|
Self::L7b
|
||||||
@ -191,6 +225,7 @@ impl Which {
|
|||||||
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
|
||||||
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
|
||||||
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
|
||||||
|
Self::DeepseekR1Llama8b => "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -363,6 +398,10 @@ impl Args {
|
|||||||
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
|
"HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
|
||||||
"smollm2-1.7b-instruct-q4_k_m.gguf",
|
"smollm2-1.7b-instruct-q4_k_m.gguf",
|
||||||
),
|
),
|
||||||
|
Which::DeepseekR1Llama8b => (
|
||||||
|
"unsloth/DeepSeek-R1-Distill-Llama-8B-GGUF",
|
||||||
|
"DeepSeek-R1-Distill-Llama-8B-Q4_K_M.gguf",
|
||||||
|
),
|
||||||
};
|
};
|
||||||
let revision = if self.which == Which::Phi3 {
|
let revision = if self.which == Which::Phi3 {
|
||||||
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
|
"5eef2ce24766d31909c0b269fe90c817a8f263fb"
|
||||||
@ -477,6 +516,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
| Which::L8b
|
| Which::L8b
|
||||||
| Which::SmolLM2_1BInstruct
|
| Which::SmolLM2_1BInstruct
|
||||||
| Which::SmolLM2_360MInstruct
|
| Which::SmolLM2_360MInstruct
|
||||||
|
| Which::DeepseekR1Llama8b
|
||||||
| Which::Phi3 => 1,
|
| Which::Phi3 => 1,
|
||||||
Which::Mixtral
|
Which::Mixtral
|
||||||
| Which::MixtralInstruct
|
| Which::MixtralInstruct
|
||||||
@ -530,6 +570,8 @@ fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
} else if args.which.is_mistral() {
|
} else if args.which.is_mistral() {
|
||||||
format!("[INST] {prompt} [/INST]")
|
format!("[INST] {prompt} [/INST]")
|
||||||
|
} else if args.which.is_deepseek() {
|
||||||
|
format!("<|User|>{prompt}<|Assistant|>")
|
||||||
} else {
|
} else {
|
||||||
prompt
|
prompt
|
||||||
}
|
}
|
||||||
@ -597,6 +639,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
let eos_token = match args.which {
|
let eos_token = match args.which {
|
||||||
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
|
Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
|
||||||
Which::L8b => "<|end_of_text|>",
|
Which::L8b => "<|end_of_text|>",
|
||||||
|
Which::DeepseekR1Llama8b => "<|end▁of▁sentence|>",
|
||||||
_ => match args.which.is_open_chat() {
|
_ => match args.which.is_open_chat() {
|
||||||
true => "<|end_of_turn|>",
|
true => "<|end_of_turn|>",
|
||||||
false => "</s>",
|
false => "</s>",
|
||||||
|
@ -2,6 +2,11 @@
|
|||||||
|
|
||||||
Reinforcement Learning examples for candle.
|
Reinforcement Learning examples for candle.
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
|
> uv is not currently compatible with pyo3 as of 2025/3/28.
|
||||||
|
|
||||||
|
## System wide python
|
||||||
|
|
||||||
This has been tested with `gymnasium` version `0.29.1`. You can install the
|
This has been tested with `gymnasium` version `0.29.1`. You can install the
|
||||||
Python package with:
|
Python package with:
|
||||||
```bash
|
```bash
|
||||||
|
@ -7,7 +7,7 @@ probabilities for the top-5 classes.
|
|||||||
## Running an example
|
## Running an example
|
||||||
|
|
||||||
```
|
```
|
||||||
$ cargo run --example resnet --release -- --image tiger.jpg
|
$ cargo run --example resnet --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
loaded image Tensor[dims 3, 224, 224; f32]
|
||||||
model built
|
model built
|
||||||
|
@ -10,9 +10,11 @@ If you want you can use the example images from this [pull request][pr], downloa
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# run the image classification task
|
# run the image classification task
|
||||||
cargo run --example segformer classify <path-to-image>
|
cargo run --example segformer classify candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
# run the segmentation task
|
# run the segmentation task
|
||||||
cargo run --example segformer segment <path-to-image>
|
cargo run --example segformer segment candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Example output for classification:
|
Example output for classification:
|
||||||
|
@ -14,8 +14,8 @@ based on [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example segment-anything --release -- \
|
cargo run --example segment-anything --release -- \
|
||||||
--image candle-examples/examples/yolo-v8/assets/bike.jpg
|
--image candle-examples/examples/yolo-v8/assets/bike.jpg \
|
||||||
--use-tiny
|
--use-tiny \
|
||||||
--point 0.6,0.6 --point 0.6,0.55
|
--point 0.6,0.6 --point 0.6,0.55
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ SigLIP is multi-modal text-vision model that improves over CLIP by using a sigmo
|
|||||||
|
|
||||||
### Running an example
|
### Running an example
|
||||||
```
|
```
|
||||||
$ cargo run --features cuda -r --example siglip -
|
$ cargo run --features cuda -r --example siglip
|
||||||
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
|
softmax_image_vec: [2.1912122e-14, 2.3624872e-14, 1.0, 1.0, 2.4787932e-8, 3.2784535e-12]
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,14 @@ This example uses the models available in the hugging face [onnx-community/siler
|
|||||||
|
|
||||||
## Running the example
|
## Running the example
|
||||||
|
|
||||||
|
### using arecord
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
$ arecord -t raw -f S16_LE -r 16000 -c 1 -d 5 - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### using SoX
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ rec -t raw -r 48000 -b 16 -c 1 -e signed-integer - trim 0 5 | sox -t raw -r 48000 -b 16 -c 1 -e signed-integer - -t raw -r 16000 -b 16 -c 1 -e signed-integer - | cargo run --example silero-vad --release --features onnx -- --sample-rate 16000
|
||||||
|
```
|
||||||
|
275
candle-examples/examples/snac/audio_io.rs
Normal file
275
candle-examples/examples/snac/audio_io.rs
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
pub const SAMPLE_RATE: usize = 24_000;
|
||||||
|
|
||||||
|
pub(crate) struct AudioOutputData_ {
|
||||||
|
resampled_data: std::collections::VecDeque<f32>,
|
||||||
|
resampler: rubato::FastFixedIn<f32>,
|
||||||
|
output_buffer: Vec<f32>,
|
||||||
|
input_buffer: Vec<f32>,
|
||||||
|
input_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AudioOutputData_ {
|
||||||
|
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
||||||
|
use rubato::Resampler;
|
||||||
|
|
||||||
|
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
||||||
|
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
||||||
|
let resampler = rubato::FastFixedIn::new(
|
||||||
|
resample_ratio,
|
||||||
|
f64::max(resample_ratio, 1.0),
|
||||||
|
rubato::PolynomialDegree::Septic,
|
||||||
|
1024,
|
||||||
|
1,
|
||||||
|
)?;
|
||||||
|
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
||||||
|
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
||||||
|
Ok(Self {
|
||||||
|
resampled_data,
|
||||||
|
resampler,
|
||||||
|
input_buffer,
|
||||||
|
output_buffer,
|
||||||
|
input_len: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
use rubato::Resampler;
|
||||||
|
self.output_buffer.fill(0.);
|
||||||
|
self.input_buffer.fill(0.);
|
||||||
|
self.resampler.reset();
|
||||||
|
self.resampled_data.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
||||||
|
let mut data = Vec::with_capacity(self.resampled_data.len());
|
||||||
|
while let Some(elem) = self.resampled_data.pop_back() {
|
||||||
|
data.push(elem);
|
||||||
|
}
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_empty(&self) -> bool {
|
||||||
|
self.resampled_data.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assumes that the input buffer is large enough.
|
||||||
|
fn push_input_buffer(&mut self, samples: &[f32]) {
|
||||||
|
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
||||||
|
self.input_len += samples.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
||||||
|
use rubato::Resampler;
|
||||||
|
|
||||||
|
let mut pos_in = 0;
|
||||||
|
loop {
|
||||||
|
let rem = self.input_buffer.len() - self.input_len;
|
||||||
|
let pos_end = usize::min(pos_in + rem, samples.len());
|
||||||
|
self.push_input_buffer(&samples[pos_in..pos_end]);
|
||||||
|
pos_in = pos_end;
|
||||||
|
if self.input_len < self.input_buffer.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let (_, out_len) = self.resampler.process_into_buffer(
|
||||||
|
&[&self.input_buffer],
|
||||||
|
&mut [&mut self.output_buffer],
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
for &elem in self.output_buffer[..out_len].iter() {
|
||||||
|
self.resampled_data.push_front(elem)
|
||||||
|
}
|
||||||
|
self.input_len = 0;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
||||||
|
|
||||||
|
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||||
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
|
|
||||||
|
println!("Setup audio output stream!");
|
||||||
|
let host = cpal::default_host();
|
||||||
|
let device = host
|
||||||
|
.default_output_device()
|
||||||
|
.context("no output device available")?;
|
||||||
|
let mut supported_configs_range = device.supported_output_configs()?;
|
||||||
|
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
||||||
|
// On macOS, it's commonly the case that there are only stereo outputs.
|
||||||
|
None => device
|
||||||
|
.supported_output_configs()?
|
||||||
|
.next()
|
||||||
|
.context("no audio output available")?,
|
||||||
|
Some(config_range) => config_range,
|
||||||
|
};
|
||||||
|
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||||
|
config_range.min_sample_rate(),
|
||||||
|
config_range.max_sample_rate(),
|
||||||
|
);
|
||||||
|
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||||
|
let channels = config.channels as usize;
|
||||||
|
println!(
|
||||||
|
"cpal device: {} {} {config:?}",
|
||||||
|
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||||
|
config.sample_rate.0
|
||||||
|
);
|
||||||
|
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||||
|
SAMPLE_RATE,
|
||||||
|
config.sample_rate.0 as usize,
|
||||||
|
)?));
|
||||||
|
let ad = audio_data.clone();
|
||||||
|
let stream = device.build_output_stream(
|
||||||
|
&config,
|
||||||
|
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||||
|
data.fill(0.);
|
||||||
|
let mut ad = ad.lock().unwrap();
|
||||||
|
let mut last_elem = 0f32;
|
||||||
|
for (idx, elem) in data.iter_mut().enumerate() {
|
||||||
|
if idx % channels == 0 {
|
||||||
|
match ad.resampled_data.pop_back() {
|
||||||
|
None => break,
|
||||||
|
Some(v) => {
|
||||||
|
last_elem = v;
|
||||||
|
*elem = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*elem = last_elem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
move |err| eprintln!("cpal error: {err}"),
|
||||||
|
None, // None=blocking, Some(Duration)=timeout
|
||||||
|
)?;
|
||||||
|
stream.play()?;
|
||||||
|
Ok((stream, audio_data))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||||
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
|
|
||||||
|
println!("Setup audio input stream!");
|
||||||
|
let host = cpal::default_host();
|
||||||
|
let device = host
|
||||||
|
.default_input_device()
|
||||||
|
.context("no input device available")?;
|
||||||
|
let mut supported_configs_range = device.supported_input_configs()?;
|
||||||
|
let config_range = supported_configs_range
|
||||||
|
.find(|c| c.channels() == 1)
|
||||||
|
.context("no audio input available")?;
|
||||||
|
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||||
|
config_range.min_sample_rate(),
|
||||||
|
config_range.max_sample_rate(),
|
||||||
|
);
|
||||||
|
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||||
|
println!(
|
||||||
|
"cpal device: {} {} {config:?}",
|
||||||
|
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||||
|
config.sample_rate.0
|
||||||
|
);
|
||||||
|
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||||
|
config.sample_rate.0 as usize,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
)?));
|
||||||
|
let ad = audio_data.clone();
|
||||||
|
let stream = device.build_input_stream(
|
||||||
|
&config,
|
||||||
|
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||||
|
let mut ad = ad.lock().unwrap();
|
||||||
|
if let Err(err) = ad.push_samples(data) {
|
||||||
|
eprintln!("error processing audio input {err:?}")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
move |err| eprintln!("cpal error: {err}"),
|
||||||
|
None, // None=blocking, Some(Duration)=timeout
|
||||||
|
)?;
|
||||||
|
stream.play()?;
|
||||||
|
Ok((stream, audio_data))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||||
|
where
|
||||||
|
T: symphonia::core::sample::Sample,
|
||||||
|
f32: symphonia::core::conv::FromSample<T>,
|
||||||
|
{
|
||||||
|
use symphonia::core::audio::Signal;
|
||||||
|
use symphonia::core::conv::FromSample;
|
||||||
|
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||||
|
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||||
|
|
||||||
|
let src = std::fs::File::open(path)?;
|
||||||
|
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||||
|
let hint = symphonia::core::probe::Hint::new();
|
||||||
|
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||||
|
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||||
|
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||||
|
let mut format = probed.format;
|
||||||
|
let track = format
|
||||||
|
.tracks()
|
||||||
|
.iter()
|
||||||
|
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||||
|
.expect("no supported audio tracks");
|
||||||
|
let mut decoder = symphonia::default::get_codecs()
|
||||||
|
.make(&track.codec_params, &Default::default())
|
||||||
|
.expect("unsupported codec");
|
||||||
|
let track_id = track.id;
|
||||||
|
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||||
|
let mut pcm_data = Vec::new();
|
||||||
|
while let Ok(packet) = format.next_packet() {
|
||||||
|
while !format.metadata().is_latest() {
|
||||||
|
format.metadata().pop();
|
||||||
|
}
|
||||||
|
if packet.track_id() != track_id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match decoder.decode(&packet)? {
|
||||||
|
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||||
|
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((pcm_data, sample_rate))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||||
|
use rubato::Resampler;
|
||||||
|
|
||||||
|
let mut pcm_out =
|
||||||
|
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||||
|
|
||||||
|
let mut resampler =
|
||||||
|
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
|
||||||
|
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||||
|
let mut pos_in = 0;
|
||||||
|
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||||
|
let (in_len, out_len) =
|
||||||
|
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
||||||
|
pos_in += in_len;
|
||||||
|
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if pos_in < pcm_in.len() {
|
||||||
|
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
||||||
|
Some(&[&pcm_in[pos_in..]]),
|
||||||
|
&mut output_buffer,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(pcm_out)
|
||||||
|
}
|
197
candle-examples/examples/snac/main.rs
Normal file
197
candle-examples/examples/snac/main.rs
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle::{DType, IndexOp, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::snac::{Config, Model};
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
|
||||||
|
mod audio_io;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Action {
|
||||||
|
AudioToAudio,
|
||||||
|
AudioToCode,
|
||||||
|
CodeToAudio,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "24khz")]
|
||||||
|
S24khz,
|
||||||
|
#[value(name = "32khz")]
|
||||||
|
S32khz,
|
||||||
|
#[value(name = "44khz")]
|
||||||
|
S44khz,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn sample_rate(&self) -> u32 {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => 24000,
|
||||||
|
Which::S32khz => 32000,
|
||||||
|
Which::S44khz => 44000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_repo(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => "hubertsiuzdak/snac_24khz",
|
||||||
|
Which::S32khz => "hubertsiuzdak/snac_32khz",
|
||||||
|
Which::S44khz => "hubertsiuzdak/snac_44khz",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_file(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => "snac_24khz.safetensors",
|
||||||
|
Which::S32khz => "snac_32khz.safetensors",
|
||||||
|
Which::S44khz => "snac_44khz.safetensors",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// The action to be performed, specifies the format for the input and output data.
|
||||||
|
action: Action,
|
||||||
|
|
||||||
|
/// The input file, either an audio file or some snac tokens stored as safetensors.
|
||||||
|
in_file: String,
|
||||||
|
|
||||||
|
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "24khz")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// The model weight file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// The config file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
config: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let model_sample_rate = args.which.sample_rate();
|
||||||
|
let config = match args.config {
|
||||||
|
Some(c) => std::path::PathBuf::from(c),
|
||||||
|
None => Api::new()?
|
||||||
|
.model(args.which.config_repo().to_string())
|
||||||
|
.get("config.json")?,
|
||||||
|
};
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
||||||
|
let model = match args.model {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("lmz/candle-snac".to_string())
|
||||||
|
.get(args.which.model_file())?,
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
let codes = match args.action {
|
||||||
|
Action::CodeToAudio => {
|
||||||
|
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||||
|
let num_codebooks = model.num_codebooks();
|
||||||
|
(0..num_codebooks)
|
||||||
|
.map(|i| {
|
||||||
|
codes
|
||||||
|
.get(&format!("codes-{i}"))
|
||||||
|
.expect("no codes in input file")
|
||||||
|
.clone()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
Action::AudioToCode | Action::AudioToAudio => {
|
||||||
|
let pcm = if args.in_file == "-" {
|
||||||
|
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
||||||
|
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
||||||
|
let mut pcms = vec![];
|
||||||
|
let stdin = std::thread::spawn(|| {
|
||||||
|
let mut s = String::new();
|
||||||
|
std::io::stdin().read_line(&mut s)
|
||||||
|
});
|
||||||
|
while !stdin.is_finished() {
|
||||||
|
let input = input_audio.lock().unwrap().take_all();
|
||||||
|
if input.is_empty() {
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
pcms.push(input)
|
||||||
|
}
|
||||||
|
drop(stream);
|
||||||
|
pcms.concat()
|
||||||
|
} else {
|
||||||
|
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||||
|
if sample_rate != model_sample_rate {
|
||||||
|
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
|
||||||
|
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
|
||||||
|
} else {
|
||||||
|
pcm
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let pcm_len = pcm.len();
|
||||||
|
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||||
|
println!("input pcm shape: {:?}", pcm.shape());
|
||||||
|
model.encode(&pcm)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for codes in codes.iter() {
|
||||||
|
println!("codes shape: {:?}", codes.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
match args.action {
|
||||||
|
Action::AudioToCode => {
|
||||||
|
let mut tensors = std::collections::HashMap::new();
|
||||||
|
for (i, codes) in codes.iter().enumerate() {
|
||||||
|
tensors.insert(format!("codes-{i}"), codes.clone());
|
||||||
|
}
|
||||||
|
candle::safetensors::save(&tensors, "codes.safetensors")?;
|
||||||
|
}
|
||||||
|
Action::AudioToAudio | Action::CodeToAudio => {
|
||||||
|
let codes = codes.iter().collect::<Vec<_>>();
|
||||||
|
let pcm = model.decode(&codes)?;
|
||||||
|
println!("output pcm shape: {:?}", pcm.shape());
|
||||||
|
let pcm = pcm.i(0)?.i(0)?;
|
||||||
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
|
||||||
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
if args.out_file == "-" {
|
||||||
|
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||||
|
{
|
||||||
|
let mut ad = ad.lock().unwrap();
|
||||||
|
ad.push_samples(&pcm)?;
|
||||||
|
}
|
||||||
|
loop {
|
||||||
|
let ad = ad.lock().unwrap();
|
||||||
|
if ad.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// That's very weird, calling thread::sleep here triggers the stream to stop
|
||||||
|
// playing (the callback doesn't seem to be called anymore).
|
||||||
|
// std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
}
|
||||||
|
drop(stream)
|
||||||
|
} else {
|
||||||
|
let mut output = std::fs::File::create(&args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
15
candle-examples/examples/starcoder2/README.md
Normal file
15
candle-examples/examples/starcoder2/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# candle-starcoder2
|
||||||
|
|
||||||
|
Candle implementation of the StarCoder 2 family of code generation models from [StarCoder 2 and The Stack v2: The Next Generation](https://arxiv.org/pdf/2402.19173).
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example starcoder2 -- --prompt "write a recursive fibonacci function in python "
|
||||||
|
|
||||||
|
> # that returns the nth number in the sequence.
|
||||||
|
>
|
||||||
|
> def fib(n):
|
||||||
|
> if n
|
||||||
|
|
||||||
|
```
|
@ -10,7 +10,7 @@ Stella_en_1.5B_v5 is used to generate text embeddings embeddings for a prompt. T
|
|||||||
are downloaded from the hub on the first run.
|
are downloaded from the hub on the first run.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
|
$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?" --which 1.5b
|
||||||
|
|
||||||
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
|
> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]]
|
||||||
> Tensor[[1, 1024], f32]
|
> Tensor[[1, 1024], f32]
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
# candle-t5
|
# candle-t5
|
||||||
|
|
||||||
|
Candle implementations of the T5 family of translation models.
|
||||||
|
|
||||||
## Encoder-decoder example:
|
## Encoder-decoder example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -7,7 +7,7 @@ The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main
|
|||||||
You can run the example with the following command:
|
You can run the example with the following command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
|
cargo run --example vgg --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which vgg13
|
||||||
```
|
```
|
||||||
|
|
||||||
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
|
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).
|
||||||
|
@ -7,8 +7,8 @@ probabilities for the top-5 classes.
|
|||||||
|
|
||||||
## Running an example
|
## Running an example
|
||||||
|
|
||||||
```
|
```bash
|
||||||
$ cargo run --example vit --release -- --image tiger.jpg
|
$ cargo run --example vit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
loaded image Tensor[dims 3, 224, 224; f32]
|
loaded image Tensor[dims 3, 224, 224; f32]
|
||||||
model built
|
model built
|
||||||
|
15
candle-examples/examples/whisper-microphone/README.md
Normal file
15
candle-examples/examples/whisper-microphone/README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# candle-whisper-microphone
|
||||||
|
|
||||||
|
Whisper implementation using the microphone as input.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example whisper-microphone --features microphone
|
||||||
|
|
||||||
|
> transcribing audio...
|
||||||
|
> 480256 160083
|
||||||
|
> language_token: None
|
||||||
|
> 0.0s -- 30.0s: Hello, hello, I don't know if this is working, but You know, how long did I make this?
|
||||||
|
> 480256 160085
|
||||||
|
```
|
13
candle-examples/examples/yi/README.md
Normal file
13
candle-examples/examples/yi/README.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# candle-yi
|
||||||
|
|
||||||
|
Candle implementations of the Yi family of bilingual (English, Chinese) LLMs.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example yi -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
|
> python
|
||||||
|
> print("Hello World")
|
||||||
|
>
|
||||||
|
```
|
32
candle-examples/examples/yolo-v3/README.md
Normal file
32
candle-examples/examples/yolo-v3/README.md
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# candle-yolo-v3
|
||||||
|
|
||||||
|
Candle implementation of Yolo-V3 for object detection.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example yolo-v3 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
|
||||||
|
|
||||||
|
> generated predictions Tensor[dims 10647, 85; f32]
|
||||||
|
> person: Bbox { xmin: 46.362198, ymin: 72.177, xmax: 135.92522, ymax: 339.8356, confidence: 0.99705493, data: () }
|
||||||
|
> person: Bbox { xmin: 137.25645, ymin: 67.58148, xmax: 216.90437, ymax: 333.80756, confidence: 0.9898516, data: () }
|
||||||
|
> person: Bbox { xmin: 245.7842, ymin: 82.76726, xmax: 316.79053, ymax: 337.21613, confidence: 0.9884322, data: () }
|
||||||
|
> person: Bbox { xmin: 207.52783, ymin: 61.815224, xmax: 266.77884, ymax: 307.92606, confidence: 0.9860648, data: () }
|
||||||
|
> person: Bbox { xmin: 11.457404, ymin: 60.335564, xmax: 34.39357, ymax: 187.7714, confidence: 0.9545012, data: () }
|
||||||
|
> person: Bbox { xmin: 251.88353, ymin: 11.235481, xmax: 286.56607, ymax: 92.54697, confidence: 0.8439807, data: () }
|
||||||
|
> person: Bbox { xmin: -0.44309902, ymin: 55.486923, xmax: 13.160354, ymax: 184.09705, confidence: 0.8266243, data: () }
|
||||||
|
> person: Bbox { xmin: 317.40826, ymin: 55.39501, xmax: 370.6704, ymax: 153.74887, confidence: 0.7327442, data: () }
|
||||||
|
> person: Bbox { xmin: 370.02835, ymin: 66.120224, xmax: 404.22824, ymax: 142.09691, confidence: 0.7265741, data: () }
|
||||||
|
> person: Bbox { xmin: 250.36511, ymin: 57.349842, xmax: 280.06335, ymax: 116.29384, confidence: 0.709422, data: () }
|
||||||
|
> person: Bbox { xmin: 32.573215, ymin: 66.66239, xmax: 50.49056, ymax: 173.42068, confidence: 0.6998766, data: () }
|
||||||
|
> person: Bbox { xmin: 131.72215, ymin: 63.946213, xmax: 166.66151, ymax: 241.52773, confidence: 0.64457536, data: () }
|
||||||
|
> person: Bbox { xmin: 407.42416, ymin: 49.106407, xmax: 415.24307, ymax: 84.7134, confidence: 0.5955802, data: () }
|
||||||
|
> person: Bbox { xmin: 51.650482, ymin: 64.4985, xmax: 67.40904, ymax: 106.952385, confidence: 0.5196007, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 160.10031, ymin: 183.90837, xmax: 200.86832, ymax: 398.609, confidence: 0.9623588, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 66.570915, ymin: 192.56966, xmax: 112.06765, ymax: 369.28497, confidence: 0.9174347, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 258.2856, ymin: 197.04532, xmax: 298.43106, ymax: 364.8627, confidence: 0.6851388, data: () }
|
||||||
|
> bicycle: Bbox { xmin: 214.0034, ymin: 175.76498, xmax: 252.45158, ymax: 356.53818, confidence: 0.67071193, data: () }
|
||||||
|
> motorbike: Bbox { xmin: 318.23938, ymin: 95.22487, xmax: 369.9743, ymax: 213.46263, confidence: 0.96691036, data: () }
|
||||||
|
> motorbike: Bbox { xmin: 367.46417, ymin: 100.07982, xmax: 394.9981, ymax: 174.6545, confidence: 0.9185384, data: () }
|
||||||
|
> writing "candle-examples/examples/yolo-v8/assets/bike.pp.jpg"
|
||||||
|
```
|
@ -133,6 +133,7 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
|
|||||||
padding,
|
padding,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
let conv = if bias {
|
let conv = if bias {
|
||||||
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
|
||||||
|
@ -92,6 +92,7 @@ impl ConvBlock {
|
|||||||
stride,
|
stride,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
|
||||||
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,14 +11,17 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.8.4" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.3" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
bindgen_cuda = "0.1.1"
|
bindgen_cuda = "0.1.1"
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
candle-nn = { path = "../candle-nn", features = ["cuda"] }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
cudnn = ["candle/cudnn"]
|
||||||
|
@ -2,7 +2,6 @@ mod ffi;
|
|||||||
|
|
||||||
use candle::backend::BackendStorage;
|
use candle::backend::BackendStorage;
|
||||||
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
use candle::cuda_backend::cudarc::driver::DevicePtr;
|
||||||
use candle::cuda_backend::WrapErr;
|
|
||||||
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
|
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
@ -88,6 +87,7 @@ impl FlashAttn {
|
|||||||
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
|
candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream();
|
||||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||||
if alibi_slopes.dtype() != DType::F32 {
|
if alibi_slopes.dtype() != DType::F32 {
|
||||||
candle::bail!(
|
candle::bail!(
|
||||||
@ -114,7 +114,9 @@ impl FlashAttn {
|
|||||||
|
|
||||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||||
|
|
||||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
// Dropping the guard here doesn't seem very safe.
|
||||||
|
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
|
||||||
|
ptr as *const core::ffi::c_void
|
||||||
} else {
|
} else {
|
||||||
std::ptr::null()
|
std::ptr::null()
|
||||||
};
|
};
|
||||||
@ -139,10 +141,8 @@ impl FlashAttn {
|
|||||||
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||||
|
|
||||||
let elem_count = out_shape.elem_count();
|
let elem_count = out_shape.elem_count();
|
||||||
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(elem_count)? };
|
||||||
let softmax_lse = dev
|
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
|
||||||
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
|
|
||||||
.w()?;
|
|
||||||
|
|
||||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||||
|
|
||||||
@ -161,17 +161,17 @@ impl FlashAttn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
let (q_ptr, _guard) = q.device_ptr(&stream);
|
||||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
let (k_ptr, _guard) = k.device_ptr(&stream);
|
||||||
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
|
let (v_ptr, _guard) = v.device_ptr(&stream);
|
||||||
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
|
let (dst_ptr, _guard) = dst.device_ptr(&stream);
|
||||||
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
|
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
|
||||||
ffi::run_mha(
|
ffi::run_mha(
|
||||||
q_ptr,
|
q_ptr as *const core::ffi::c_void,
|
||||||
k_ptr,
|
k_ptr as *const core::ffi::c_void,
|
||||||
v_ptr,
|
v_ptr as *const core::ffi::c_void,
|
||||||
dst_ptr,
|
dst_ptr as *const core::ffi::c_void,
|
||||||
softmax_lse_ptr,
|
softmax_lse_ptr as *const core::ffi::c_void,
|
||||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
||||||
/* cu_seqlens_q_ptr */ std::ptr::null(),
|
/* cu_seqlens_q_ptr */ std::ptr::null(),
|
||||||
/* cu_seqlens_k_ptr */ std::ptr::null(),
|
/* cu_seqlens_k_ptr */ std::ptr::null(),
|
||||||
@ -550,6 +550,7 @@ impl FlashAttnVarLen {
|
|||||||
|
|
||||||
let batch_size = nseqlens_q - 1;
|
let batch_size = nseqlens_q - 1;
|
||||||
|
|
||||||
|
let stream = dev.cuda_stream();
|
||||||
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
|
||||||
if alibi_slopes.dtype() != DType::F32 {
|
if alibi_slopes.dtype() != DType::F32 {
|
||||||
candle::bail!(
|
candle::bail!(
|
||||||
@ -576,7 +577,9 @@ impl FlashAttnVarLen {
|
|||||||
|
|
||||||
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
|
||||||
|
|
||||||
*alibi_slopes.device_ptr() as *const core::ffi::c_void
|
// Dropping the guard here doesn't seem very safe.
|
||||||
|
let (ptr, _guard) = alibi_slopes.device_ptr(&stream);
|
||||||
|
ptr as *const core::ffi::c_void
|
||||||
} else {
|
} else {
|
||||||
std::ptr::null()
|
std::ptr::null()
|
||||||
};
|
};
|
||||||
@ -601,8 +604,8 @@ impl FlashAttnVarLen {
|
|||||||
let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
|
let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
|
||||||
|
|
||||||
let elem_count = out_shape.elem_count();
|
let elem_count = out_shape.elem_count();
|
||||||
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
|
||||||
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
|
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
|
||||||
|
|
||||||
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
let is_bf16 = if is_bf16 { 1 } else { 0 };
|
||||||
|
|
||||||
@ -621,22 +624,22 @@ impl FlashAttnVarLen {
|
|||||||
}
|
}
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
|
let (q_ptr, _guard) = q.device_ptr(&stream);
|
||||||
let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
|
let (k_ptr, _guard) = k.device_ptr(&stream);
|
||||||
let v_ptr = *v.device_ptr() as *const core::ffi::c_void;
|
let (v_ptr, _guard) = v.device_ptr(&stream);
|
||||||
let dst_ptr = *dst.device_ptr() as *const core::ffi::c_void;
|
let (dst_ptr, _guard) = dst.device_ptr(&stream);
|
||||||
let softmax_lse_ptr = *softmax_lse.device_ptr() as *const core::ffi::c_void;
|
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr(&stream);
|
||||||
let seqlens_q_ptr = *seqlens_q.device_ptr() as *const core::ffi::c_int;
|
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
|
||||||
let seqlens_k_ptr = *seqlens_k.device_ptr() as *const core::ffi::c_int;
|
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
|
||||||
ffi::run_mha(
|
ffi::run_mha(
|
||||||
q_ptr,
|
q_ptr as *const core::ffi::c_void,
|
||||||
k_ptr,
|
k_ptr as *const core::ffi::c_void,
|
||||||
v_ptr,
|
v_ptr as *const core::ffi::c_void,
|
||||||
dst_ptr,
|
dst_ptr as *const core::ffi::c_void,
|
||||||
softmax_lse_ptr,
|
softmax_lse_ptr as *const core::ffi::c_void,
|
||||||
/* alibi_slopes_ptr */ alibi_slopes_ptr,
|
/* alibi_slopes_ptr */ alibi_slopes_ptr as *const core::ffi::c_void,
|
||||||
/* cu_seqlens_q_ptr */ seqlens_q_ptr,
|
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
|
||||||
/* cu_seqlens_k_ptr */ seqlens_k_ptr,
|
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
|
||||||
/* q_batch_stride */ 0,
|
/* q_batch_stride */ 0,
|
||||||
/* k_batch_stride */ 0,
|
/* k_batch_stride */ 0,
|
||||||
/* v_batch_stride */ 0,
|
/* v_batch_stride */ 0,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -7,5 +7,5 @@ fn main() {
|
|||||||
let builder = bindgen_cuda::Builder::default();
|
let builder = bindgen_cuda::Builder::default();
|
||||||
println!("cargo:info={builder:?}");
|
println!("cargo:info={builder:?}");
|
||||||
let bindings = builder.build_ptx().unwrap();
|
let bindings = builder.build_ptx().unwrap();
|
||||||
bindings.write("src/lib.rs").unwrap();
|
bindings.write("src/ptx.rs").unwrap();
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ __device__ void conv1d(
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ void im2col1d(
|
__device__ void im2col1d(
|
||||||
const size_t dst_numel,
|
const size_t numel,
|
||||||
const size_t l_out,
|
const size_t l_out,
|
||||||
const size_t l_k,
|
const size_t l_k,
|
||||||
const size_t stride,
|
const size_t stride,
|
||||||
@ -63,10 +63,10 @@ __device__ void im2col1d(
|
|||||||
const T *src,
|
const T *src,
|
||||||
T *dst
|
T *dst
|
||||||
) {
|
) {
|
||||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
// dst: (b_size, l_out, c_in, l_k)
|
// dst: (b_size, l_out, c_in, l_k)
|
||||||
// src: (b_size, c_in, l_in)
|
// src: (b_size, c_in, l_in)
|
||||||
if (dst_i >= dst_numel) {
|
if (thread_i >= numel) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const size_t *src_dims = info;
|
const size_t *src_dims = info;
|
||||||
@ -74,26 +74,26 @@ __device__ void im2col1d(
|
|||||||
const size_t c_in = src_dims[1];
|
const size_t c_in = src_dims[1];
|
||||||
const size_t l_in = src_dims[2];
|
const size_t l_in = src_dims[2];
|
||||||
|
|
||||||
const size_t dst_s2 = l_k;
|
const size_t dst_s1 = c_in;
|
||||||
const size_t dst_s1 = c_in * dst_s2;
|
|
||||||
const size_t dst_s0 = l_out * dst_s1;
|
const size_t dst_s0 = l_out * dst_s1;
|
||||||
|
|
||||||
size_t tmp_dst_i = dst_i;
|
size_t tmp_dst_i = thread_i;
|
||||||
const size_t b_idx = tmp_dst_i / dst_s0;
|
const size_t b_idx = tmp_dst_i / dst_s0;
|
||||||
tmp_dst_i -= b_idx * dst_s0;
|
tmp_dst_i -= b_idx * dst_s0;
|
||||||
const size_t l_idx = tmp_dst_i / dst_s1;
|
const size_t l_idx = tmp_dst_i / dst_s1;
|
||||||
tmp_dst_i -= l_idx * dst_s1;
|
tmp_dst_i -= l_idx * dst_s1;
|
||||||
const size_t c_idx = tmp_dst_i / dst_s2;
|
const size_t c_idx = tmp_dst_i;
|
||||||
tmp_dst_i -= c_idx * dst_s2;
|
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
|
||||||
const size_t l_k_idx = tmp_dst_i;
|
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
||||||
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
|
size_t dst_i = thread_i * l_k + l_k_idx;
|
||||||
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
|
||||||
dst[dst_i] = static_cast<T>(0);
|
dst[dst_i] = static_cast<T>(0);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
src_l_idx -= padding;
|
src_l_idx -= padding;
|
||||||
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
|
||||||
dst[dst_i] = src[src_i];
|
dst[dst_i] = src[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,11 +1,78 @@
|
|||||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
mod ptx;
|
||||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
|
||||||
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
#[repr(u32)]
|
||||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
pub enum Id {
|
||||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
Affine,
|
||||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
Binary,
|
||||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
Cast,
|
||||||
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
Conv,
|
||||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
Fill,
|
||||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
Indexing,
|
||||||
|
Quantized,
|
||||||
|
Reduce,
|
||||||
|
Sort,
|
||||||
|
Ternary,
|
||||||
|
Unary,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const ALL_IDS: [Id; 11] = [
|
||||||
|
Id::Affine,
|
||||||
|
Id::Binary,
|
||||||
|
Id::Cast,
|
||||||
|
Id::Conv,
|
||||||
|
Id::Fill,
|
||||||
|
Id::Indexing,
|
||||||
|
Id::Quantized,
|
||||||
|
Id::Reduce,
|
||||||
|
Id::Sort,
|
||||||
|
Id::Ternary,
|
||||||
|
Id::Unary,
|
||||||
|
];
|
||||||
|
|
||||||
|
/// A kernel module: some PTX source code together with the index attributed to this module.
///
/// Values of this type are only built within this crate (the fields are private), users get
/// access to the data via the [`Module::index`] and [`Module::ptx`] accessors.
#[derive(Debug, Clone, Copy)]
pub struct Module {
    index: usize,
    ptx: &'static str,
}

impl Module {
    /// The index attributed to this module.
    pub fn index(&self) -> usize {
        self.index
    }

    /// The PTX source code for this module.
    pub fn ptx(&self) -> &'static str {
        self.ptx
    }
}
|
||||||
|
|
||||||
|
const fn module_index(id: Id) -> usize {
|
||||||
|
let mut i = 0;
|
||||||
|
while i < ALL_IDS.len() {
|
||||||
|
if ALL_IDS[i] as u32 == id as u32 {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
panic!("id not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! mdl {
|
||||||
|
($cst:ident, $id:ident) => {
|
||||||
|
pub const $cst: Module = Module {
|
||||||
|
index: module_index(Id::$id),
|
||||||
|
ptx: ptx::$cst,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
mdl!(AFFINE, Affine);
|
||||||
|
mdl!(BINARY, Binary);
|
||||||
|
mdl!(CAST, Cast);
|
||||||
|
mdl!(CONV, Conv);
|
||||||
|
mdl!(FILL, Fill);
|
||||||
|
mdl!(INDEXING, Indexing);
|
||||||
|
mdl!(QUANTIZED, Quantized);
|
||||||
|
mdl!(REDUCE, Reduce);
|
||||||
|
mdl!(SORT, Sort);
|
||||||
|
mdl!(TERNARY, Ternary);
|
||||||
|
mdl!(UNARY, Unary);
|
||||||
|
11
candle-kernels/src/ptx.rs
Normal file
11
candle-kernels/src/ptx.rs
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||||
|
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||||
|
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||||
|
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||||
|
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||||
|
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||||
|
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||||
|
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||||
|
pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx"));
|
||||||
|
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||||
|
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
|
@ -33,6 +33,7 @@ criterion = { workspace = true }
|
|||||||
default = []
|
default = []
|
||||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||||
cuda = ["candle/cuda"]
|
cuda = ["candle/cuda"]
|
||||||
|
cudnn = ["candle/cudnn"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||||
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
|
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
//! Convolution Layers.
|
//! Convolution Layers.
|
||||||
use crate::BatchNorm;
|
use crate::BatchNorm;
|
||||||
use candle::{Result, Tensor};
|
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub struct Conv1dConfig {
|
pub struct Conv1dConfig {
|
||||||
@ -8,6 +8,7 @@ pub struct Conv1dConfig {
|
|||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub dilation: usize,
|
pub dilation: usize,
|
||||||
pub groups: usize,
|
pub groups: usize,
|
||||||
|
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Conv1dConfig {
|
impl Default for Conv1dConfig {
|
||||||
@ -17,6 +18,7 @@ impl Default for Conv1dConfig {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -52,12 +54,13 @@ impl Conv1d {
|
|||||||
|
|
||||||
impl crate::Module for Conv1d {
|
impl crate::Module for Conv1d {
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.conv1d(
|
let x = x.conv1d_with_algo(
|
||||||
&self.weight,
|
&self.weight,
|
||||||
self.config.padding,
|
self.config.padding,
|
||||||
self.config.stride,
|
self.config.stride,
|
||||||
self.config.dilation,
|
self.config.dilation,
|
||||||
self.config.groups,
|
self.config.groups,
|
||||||
|
self.config.cudnn_fwd_algo,
|
||||||
)?;
|
)?;
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
@ -147,6 +150,7 @@ pub struct Conv2dConfig {
|
|||||||
pub stride: usize,
|
pub stride: usize,
|
||||||
pub dilation: usize,
|
pub dilation: usize,
|
||||||
pub groups: usize,
|
pub groups: usize,
|
||||||
|
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Conv2dConfig {
|
impl Default for Conv2dConfig {
|
||||||
@ -156,6 +160,7 @@ impl Default for Conv2dConfig {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -211,12 +216,13 @@ impl Conv2d {
|
|||||||
|
|
||||||
impl crate::Module for Conv2d {
|
impl crate::Module for Conv2d {
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.conv2d(
|
let x = x.conv2d_with_algo(
|
||||||
&self.weight,
|
&self.weight,
|
||||||
self.config.padding,
|
self.config.padding,
|
||||||
self.config.stride,
|
self.config.stride,
|
||||||
self.config.dilation,
|
self.config.dilation,
|
||||||
self.config.groups,
|
self.config.groups,
|
||||||
|
self.config.cudnn_fwd_algo,
|
||||||
)?;
|
)?;
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
|
@ -31,6 +31,7 @@ pub mod ops;
|
|||||||
pub mod optim;
|
pub mod optim;
|
||||||
pub mod rnn;
|
pub mod rnn;
|
||||||
pub mod rotary_emb;
|
pub mod rotary_emb;
|
||||||
|
pub mod sampling;
|
||||||
pub mod sequential;
|
pub mod sequential;
|
||||||
pub mod var_builder;
|
pub mod var_builder;
|
||||||
pub mod var_map;
|
pub mod var_map;
|
||||||
|
@ -41,12 +41,36 @@ impl Linear {
|
|||||||
|
|
||||||
impl super::Module for Linear {
|
impl super::Module for Linear {
|
||||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
let w = match *x.dims() {
|
// When possible, we avoid using a broadcasted matmul as it is much slower
|
||||||
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
|
// than the standard matmul for the cuda and cpu backends.
|
||||||
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
let x = match *x.dims() {
|
||||||
_ => self.weight.t()?,
|
[b1, b2, m, k] => {
|
||||||
|
if x.is_contiguous() {
|
||||||
|
let w = self.weight.t()?;
|
||||||
|
x.reshape((b1 * b2 * m, k))?
|
||||||
|
.matmul(&w)?
|
||||||
|
.reshape((b1, b2, m, ()))?
|
||||||
|
} else {
|
||||||
|
let w = self.weight.broadcast_left((b1, b2))?.t()?;
|
||||||
|
x.matmul(&w)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
[bsize, m, k] => {
|
||||||
|
if x.is_contiguous() {
|
||||||
|
let w = self.weight.t()?;
|
||||||
|
x.reshape((bsize * m, k))?
|
||||||
|
.matmul(&w)?
|
||||||
|
.reshape((bsize, m, ()))?
|
||||||
|
} else {
|
||||||
|
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||||
|
x.matmul(&w)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let w = self.weight.t()?;
|
||||||
|
x.matmul(&w)?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let x = x.matmul(&w)?;
|
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
Some(bias) => x.broadcast_add(bias),
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
@ -7,7 +7,7 @@ use candle::{Result, Tensor};
|
|||||||
/// Arguments
|
/// Arguments
|
||||||
///
|
///
|
||||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories. This is expected to contain log probabilities.
|
/// of categories. This is expected to contain log probabilities.
|
||||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
||||||
///
|
///
|
||||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
|||||||
/// Arguments
|
/// Arguments
|
||||||
///
|
///
|
||||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories. This is expected to contain raw logits.
|
/// of categories. This is expected to contain raw logits.
|
||||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
|
||||||
///
|
///
|
||||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
|||||||
/// Arguments
|
/// Arguments
|
||||||
///
|
///
|
||||||
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories. This is expected to contain raw logits.
|
/// of categories. This is expected to contain raw logits.
|
||||||
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
|
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
|
||||||
/// of categories.
|
/// of categories.
|
||||||
///
|
///
|
||||||
/// The resulting tensor is a scalar containing the average value over the batch.
|
/// The resulting tensor is a scalar containing the average value over the batch.
|
||||||
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
|
||||||
|
@ -90,7 +90,7 @@ impl candle::CustomOp1 for Sigmoid {
|
|||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::backend::BackendStorage;
|
use candle::backend::BackendStorage;
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::SlicePtrOrNull;
|
use candle::cuda_backend::SlicePtrOrNull;
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||||
@ -110,13 +110,17 @@ impl candle::CustomOp1 for Sigmoid {
|
|||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
let out = unsafe { dev.alloc::<T>(el_count)? };
|
||||||
|
|
||||||
let params = (el_count, dims.len(), &ds, src, &out);
|
let mut builder = func.builder();
|
||||||
|
candle::builder_arg!(builder, el_count, dims.len());
|
||||||
|
ds.builder_arg(&mut builder);
|
||||||
|
builder.arg(src);
|
||||||
|
builder.arg(&out);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(out)
|
Ok(out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -340,7 +344,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
layout: &Layout,
|
layout: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -367,12 +371,15 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
block_dim: (1, 32, 1),
|
block_dim: (1, 32, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||||
let params = (&src, &dst, n_cols as i32);
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&src);
|
||||||
|
builder.arg(&dst);
|
||||||
|
candle::builder_arg!(builder, n_cols as i32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -516,7 +523,7 @@ impl candle::CustomOp2 for RmsNorm {
|
|||||||
l2: &Layout,
|
l2: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map2, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -552,19 +559,16 @@ impl candle::CustomOp2 for RmsNorm {
|
|||||||
block_dim: (block_size, 1, 1),
|
block_dim: (block_size, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
&alpha,
|
builder.arg(&alpha);
|
||||||
n_cols as i32,
|
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||||
block_size as i32,
|
|
||||||
self.eps,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -751,7 +755,7 @@ impl candle::CustomOp3 for LayerNorm {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -793,20 +797,18 @@ impl candle::CustomOp3 for LayerNorm {
|
|||||||
block_dim: (block_size, 1, 1),
|
block_dim: (block_size, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
let func =
|
||||||
|
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
&alpha,
|
builder.arg(&alpha);
|
||||||
&beta,
|
builder.arg(&beta);
|
||||||
n_cols as i32,
|
candle::builder_arg!(builder, n_cols as i32, block_size as i32, self.eps);
|
||||||
block_size as i32,
|
|
||||||
self.eps,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,7 +88,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -117,12 +117,17 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||||
let params = (&src, &cos, &sin, &dst, (b * h) as u32, (t * d) as u32);
|
let mut builder = func.builder();
|
||||||
|
builder.arg(&src);
|
||||||
|
builder.arg(&cos);
|
||||||
|
builder.arg(&sin);
|
||||||
|
builder.arg(&dst);
|
||||||
|
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,7 +338,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -362,20 +367,17 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
let (b, h, t, d) = l_src.shape().dims4()?;
|
let (b, h, t, d) = l_src.shape().dims4()?;
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src,
|
builder.arg(&src);
|
||||||
&cos,
|
builder.arg(&cos);
|
||||||
&sin,
|
builder.arg(&sin);
|
||||||
&dst,
|
builder.arg(&dst);
|
||||||
(b * h) as u32,
|
candle::builder_arg!(builder, (b * h) as u32, (t * d) as u32, d as u32);
|
||||||
(t * d) as u32,
|
|
||||||
d as u32,
|
|
||||||
);
|
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -587,7 +589,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
l3: &Layout,
|
l3: &Layout,
|
||||||
) -> Result<(candle::CudaStorage, Shape)> {
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
use candle::cuda_backend::cudarc::driver::{
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg,
|
||||||
};
|
};
|
||||||
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
use candle::cuda_backend::{kernel_name, kernels, WrapErr};
|
||||||
use candle::{CudaDevice, WithDType};
|
use candle::{CudaDevice, WithDType};
|
||||||
@ -616,14 +618,17 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
let (b, t, h, d) = l_src.shape().dims4()?;
|
let (b, t, h, d) = l_src.shape().dims4()?;
|
||||||
let el = b * h * t * d;
|
let el = b * h * t * d;
|
||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
let dst = unsafe { dev.alloc::<T>(el)? };
|
||||||
let params = (
|
let mut builder = func.builder();
|
||||||
&src, &cos, &sin, &dst, b as u32, t as u32, h as u32, d as u32,
|
builder.arg(&src);
|
||||||
);
|
builder.arg(&cos);
|
||||||
|
builder.arg(&sin);
|
||||||
|
builder.arg(&dst);
|
||||||
|
candle::builder_arg!(builder, b as u32, t as u32, h as u32, d as u32);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
20
candle-nn/src/sampling.rs
Normal file
20
candle-nn/src/sampling.rs
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
|
/// Sample according to the Gumbel-Softmax distribution.
|
||||||
|
pub fn gumbel_softmax<D: candle::shape::Dim>(
|
||||||
|
logits: &Tensor,
|
||||||
|
temperature: f64,
|
||||||
|
dim: D,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
if temperature <= 0.0 {
|
||||||
|
logits.argmax(dim)
|
||||||
|
} else if temperature == 1.0 {
|
||||||
|
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||||
|
let sampled = (logits - minus_g)?.argmax(dim)?;
|
||||||
|
Ok(sampled)
|
||||||
|
} else {
|
||||||
|
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
||||||
|
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
||||||
|
Ok(sampled)
|
||||||
|
}
|
||||||
|
}
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.8.4"
|
version = "0.9.0-alpha.3"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.8.4" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.8.4" }
|
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.3" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -29,6 +29,7 @@ tracing = { workspace = true }
|
|||||||
default = []
|
default = []
|
||||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
||||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||||
|
cudnn = ["candle/cudnn", "candle-nn/cudnn"]
|
||||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
|
@ -13,6 +13,8 @@ pub enum Sampling {
|
|||||||
TopK { k: usize, temperature: f64 },
|
TopK { k: usize, temperature: f64 },
|
||||||
TopP { p: f64, temperature: f64 },
|
TopP { p: f64, temperature: f64 },
|
||||||
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||||
|
// Note that the rng is not used for the Gumbel-Softmax sampling.
|
||||||
|
GumbelSoftmax { temperature: f64 },
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitsProcessor {
|
pub struct LogitsProcessor {
|
||||||
@ -49,6 +51,11 @@ impl LogitsProcessor {
|
|||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
|
||||||
|
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
|
||||||
|
sampled.to_vec0::<u32>()
|
||||||
|
}
|
||||||
|
|
||||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||||
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||||
let next_token = distr.sample(&mut self.rng) as u32;
|
let next_token = distr.sample(&mut self.rng) as u32;
|
||||||
@ -127,6 +134,9 @@ impl LogitsProcessor {
|
|||||||
|
|
||||||
let next_token = match &self.sampling {
|
let next_token = match &self.sampling {
|
||||||
Sampling::ArgMax => self.sample_argmax(logits)?,
|
Sampling::ArgMax => self.sample_argmax(logits)?,
|
||||||
|
Sampling::GumbelSoftmax { temperature } => {
|
||||||
|
self.sample_gumbel_softmax(&logits, *temperature)?
|
||||||
|
}
|
||||||
Sampling::All { temperature } => {
|
Sampling::All { temperature } => {
|
||||||
let prs = prs(*temperature)?;
|
let prs = prs(*temperature)?;
|
||||||
self.sample_multinomial(&prs)?
|
self.sample_multinomial(&prs)?
|
||||||
|
@ -504,8 +504,9 @@ impl BertModel {
|
|||||||
Some(attention_mask) => attention_mask.clone(),
|
Some(attention_mask) => attention_mask.clone(),
|
||||||
None => input_ids.ones_like()?,
|
None => input_ids.ones_like()?,
|
||||||
};
|
};
|
||||||
|
let dtype = embedding_output.dtype();
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
||||||
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||||
Ok(sequence_output)
|
Ok(sequence_output)
|
||||||
}
|
}
|
||||||
@ -519,8 +520,11 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
};
|
};
|
||||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||||
// torch.finfo(dtype).min
|
// torch.finfo(dtype).min
|
||||||
(attention_mask.ones_like()? - &attention_mask)?
|
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
||||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
&Tensor::try_from(f32::MIN)?
|
||||||
|
.to_device(attention_mask.device())?
|
||||||
|
.to_dtype(dtype)?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
||||||
|
@ -514,8 +514,9 @@ impl ChineseClipTextTransformer {
|
|||||||
Some(attention_mask) => attention_mask.clone(),
|
Some(attention_mask) => attention_mask.clone(),
|
||||||
None => input_ids.ones_like()?,
|
None => input_ids.ones_like()?,
|
||||||
};
|
};
|
||||||
|
let dtype = embedding_output.dtype();
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
||||||
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||||
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
||||||
let pooled_output = match &self.pooler {
|
let pooled_output = match &self.pooler {
|
||||||
@ -535,6 +536,9 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
};
|
};
|
||||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||||
// torch.finfo(dtype).min
|
// torch.finfo(dtype).min
|
||||||
(attention_mask.ones_like()? - &attention_mask)?
|
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
||||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
&Tensor::try_from(f32::MIN)?
|
||||||
|
.to_device(attention_mask.device())?
|
||||||
|
.to_dtype(dtype)?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
533
candle-transformers/src/models/csm.rs
Normal file
533
candle-transformers/src/models/csm.rs
Normal file
@ -0,0 +1,533 @@
|
|||||||
|
//! Implementation of the Conversational Speech Model (CSM) from Sesame
|
||||||
|
//!
|
||||||
|
//! See: CSM (Conversational Speech Model).
|
||||||
|
//!
|
||||||
|
//! CSM (Conversational Speech Model) is a speech generation model from Sesame that generates RVQ
|
||||||
|
//! audio codes from text and audio inputs. The model architecture employs a Llama backbone and a
|
||||||
|
//! smaller audio decoder that produces Mimi audio codes.
|
||||||
|
//!
|
||||||
|
use crate::generation::LogitsProcessor;
|
||||||
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{embedding, linear_b, Embedding, Linear, RmsNorm, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum Flavor {
|
||||||
|
#[serde(rename = "llama-1B")]
|
||||||
|
Llama1B,
|
||||||
|
#[serde(rename = "llama-100M")]
|
||||||
|
Llama100M,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub audio_num_codebooks: usize,
|
||||||
|
pub audio_vocab_size: usize,
|
||||||
|
pub backbone_flavor: Flavor,
|
||||||
|
pub decoder_flavor: Flavor,
|
||||||
|
pub text_vocab_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hyper-parameters of one Llama stack; built by `LlamaConfig::from_flavor`.
// NOTE(review): `allow(unused)` is presumably here because not every field is read by the
// model code - confirm before removing.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct LlamaConfig {
    vocab_size: usize,
    num_layers: usize,
    /// Number of query heads; `embed_dim / num_heads` is the per-head dimension.
    num_heads: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    embed_dim: usize,
    /// Number of positions for which the rotary cos/sin tables are precomputed.
    max_seq_len: usize,
    intermediate_dim: usize,
    norm_eps: f64,
    /// Base of the rotary inverse-frequency schedule.
    rope_base: f32,
    /// Divisor applied to the long-wavelength rotary frequencies.
    scale_factor: usize,
}
|
||||||
|
|
||||||
|
impl LlamaConfig {
|
||||||
|
pub fn from_flavor(flavor: Flavor) -> Self {
|
||||||
|
match flavor {
|
||||||
|
Flavor::Llama1B => Self {
|
||||||
|
vocab_size: 128256,
|
||||||
|
num_layers: 16,
|
||||||
|
num_heads: 32,
|
||||||
|
num_kv_heads: 8,
|
||||||
|
embed_dim: 2048,
|
||||||
|
max_seq_len: 2048,
|
||||||
|
intermediate_dim: 8192,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
rope_base: 500_000.,
|
||||||
|
scale_factor: 32,
|
||||||
|
},
|
||||||
|
Flavor::Llama100M => Self {
|
||||||
|
vocab_size: 128256,
|
||||||
|
num_layers: 4,
|
||||||
|
num_heads: 8,
|
||||||
|
num_kv_heads: 2,
|
||||||
|
embed_dim: 1024,
|
||||||
|
max_seq_len: 2048,
|
||||||
|
intermediate_dim: 8192,
|
||||||
|
norm_eps: 1e-5,
|
||||||
|
rope_base: 500_000.,
|
||||||
|
scale_factor: 32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_default_inv_freq(cfg: &LlamaConfig) -> Vec<f32> {
|
||||||
|
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||||
|
(0..head_dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_base.powf(i as f32 / head_dim as f32))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dtype: DType, cfg: &LlamaConfig, dev: &Device) -> Result<Self> {
|
||||||
|
let low_freq_factor = 1.0;
|
||||||
|
let high_freq_factor = 4.0;
|
||||||
|
let original_max_position_embeddings = 8192;
|
||||||
|
let scale_factor = cfg.scale_factor as f32;
|
||||||
|
let theta = {
|
||||||
|
let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
|
||||||
|
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
|
||||||
|
|
||||||
|
calculate_default_inv_freq(cfg)
|
||||||
|
.into_iter()
|
||||||
|
.map(|freq| {
|
||||||
|
let wavelen = 2. * std::f32::consts::PI / freq;
|
||||||
|
if wavelen < high_freq_wavelen {
|
||||||
|
freq
|
||||||
|
} else if wavelen > low_freq_wavelen {
|
||||||
|
freq / scale_factor
|
||||||
|
} else {
|
||||||
|
let smooth = (original_max_position_embeddings as f32 / wavelen
|
||||||
|
- low_freq_factor)
|
||||||
|
/ (high_freq_factor - low_freq_factor);
|
||||||
|
(1. - smooth) * freq / scale_factor + smooth * freq
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
};
|
||||||
|
|
||||||
|
let theta = Tensor::new(theta, dev)?;
|
||||||
|
let idx_theta = Tensor::arange(0, cfg.max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((cfg.max_seq_len, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
// This is different from the paper, see:
|
||||||
|
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||||
|
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||||
|
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||||
|
Ok(Self { cos, sin })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn rms_norm(hidden_size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
|
||||||
|
let weight = vb.get((hidden_size,), "scale")?;
|
||||||
|
Ok(RmsNorm::new(weight, eps))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
num_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let head_dim = cfg.embed_dim / cfg.num_heads;
|
||||||
|
let kv_dim = cfg.num_kv_heads * head_dim;
|
||||||
|
|
||||||
|
let q_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_b(cfg.embed_dim, kv_dim, false, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear_b(cfg.embed_dim, cfg.embed_dim, false, vb.pp("output_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache: None,
|
||||||
|
num_heads: cfg.num_heads,
|
||||||
|
num_kv_heads: cfg.num_kv_heads,
|
||||||
|
num_kv_groups: cfg.num_heads / cfg.num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, q_len, _) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
|
||||||
|
let (query_states, key_states) =
|
||||||
|
self.rotary_emb
|
||||||
|
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
|
||||||
|
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
|
||||||
|
|
||||||
|
let attn_output = {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&value_states)?
|
||||||
|
};
|
||||||
|
attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, q_len, self.num_heads * self.head_dim))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Mlp {
|
||||||
|
w1: Linear,
|
||||||
|
w2: Linear,
|
||||||
|
w3: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let w1 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w1"))?;
|
||||||
|
let w2 = linear_b(cfg.intermediate_dim, cfg.embed_dim, false, vb.pp("w2"))?;
|
||||||
|
let w3 = linear_b(cfg.embed_dim, cfg.intermediate_dim, false, vb.pp("w3"))?;
|
||||||
|
Ok(Self { w1, w2, w3 })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Mlp {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = xs.apply(&self.w1)?.silu()?;
|
||||||
|
let rhs = xs.apply(&self.w3)?;
|
||||||
|
(lhs * rhs)?.apply(&self.w2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Layer {
|
||||||
|
mlp_norm: RmsNorm,
|
||||||
|
sa_norm: RmsNorm,
|
||||||
|
attn: Attention,
|
||||||
|
mlp: Mlp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Layer {
|
||||||
|
fn new(cfg: &LlamaConfig, rotary_emb: Arc<RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let mlp_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("mlp_norm"))?;
|
||||||
|
let sa_norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("sa_norm"))?;
|
||||||
|
let attn = Attention::new(cfg, rotary_emb, vb.pp("attn"))?;
|
||||||
|
let mlp = Mlp::new(cfg, vb.pp("mlp"))?;
|
||||||
|
Ok(Self {
|
||||||
|
mlp_norm,
|
||||||
|
sa_norm,
|
||||||
|
attn,
|
||||||
|
mlp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.sa_norm.forward(xs)?;
|
||||||
|
let xs = self.attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?;
|
||||||
|
residual + xs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.attn.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct LlamaModel {
|
||||||
|
layers: Vec<Layer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlamaModel {
|
||||||
|
pub fn new(cfg: &LlamaConfig, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_layers);
|
||||||
|
let vb_l = vb.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_layers {
|
||||||
|
let layer = Layer::new(cfg, rotary_emb.clone(), vb_l.pp(layer_idx))?;
|
||||||
|
layers.push(layer);
|
||||||
|
}
|
||||||
|
let norm = rms_norm(cfg.embed_dim, cfg.norm_eps, vb.pp("norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
&self,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (_b_size, seq_len, _embed_dim) = xs.dims3()?;
|
||||||
|
let attention_mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
|
||||||
|
Some(mask)
|
||||||
|
};
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?;
|
||||||
|
}
|
||||||
|
let ys = xs.narrow(1, seq_len - 1, 1)?.apply(&self.norm)?;
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
backbone: LlamaModel,
|
||||||
|
decoder: LlamaModel,
|
||||||
|
codebook0_head: Linear,
|
||||||
|
audio_embeddings: Embedding,
|
||||||
|
text_embeddings: Embedding,
|
||||||
|
projection: Linear,
|
||||||
|
audio_head: Tensor,
|
||||||
|
config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let backbone_cfg = LlamaConfig::from_flavor(cfg.backbone_flavor);
|
||||||
|
let backbone = LlamaModel::new(&backbone_cfg, vb.pp("backbone"))?;
|
||||||
|
let decoder_cfg = LlamaConfig::from_flavor(cfg.decoder_flavor);
|
||||||
|
let decoder = LlamaModel::new(&decoder_cfg, vb.pp("decoder"))?;
|
||||||
|
let backbone_dim = backbone_cfg.embed_dim;
|
||||||
|
let decoder_dim = decoder_cfg.embed_dim;
|
||||||
|
let audio_embeddings = embedding(
|
||||||
|
cfg.audio_vocab_size * cfg.audio_num_codebooks,
|
||||||
|
backbone_dim,
|
||||||
|
vb.pp("audio_embeddings"),
|
||||||
|
)?;
|
||||||
|
let text_embeddings =
|
||||||
|
embedding(cfg.text_vocab_size, backbone_dim, vb.pp("text_embeddings"))?;
|
||||||
|
let projection = linear_b(backbone_dim, decoder_dim, false, vb.pp("projection"))?;
|
||||||
|
let codebook0_head = linear_b(
|
||||||
|
backbone_dim,
|
||||||
|
cfg.audio_vocab_size,
|
||||||
|
false,
|
||||||
|
vb.pp("codebook0_head"),
|
||||||
|
)?;
|
||||||
|
let audio_head = vb.get(
|
||||||
|
(
|
||||||
|
cfg.audio_num_codebooks - 1,
|
||||||
|
decoder_dim,
|
||||||
|
cfg.audio_vocab_size,
|
||||||
|
),
|
||||||
|
"audio_head",
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
backbone,
|
||||||
|
decoder,
|
||||||
|
codebook0_head,
|
||||||
|
audio_embeddings,
|
||||||
|
text_embeddings,
|
||||||
|
projection,
|
||||||
|
audio_head,
|
||||||
|
config: cfg.clone(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.backbone.clear_kv_cache();
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_frame(
|
||||||
|
&mut self,
|
||||||
|
tokens: &Tensor,
|
||||||
|
tokens_mask: &Tensor,
|
||||||
|
input_pos: usize,
|
||||||
|
lp: &mut LogitsProcessor,
|
||||||
|
) -> Result<Vec<u32>> {
|
||||||
|
let (b_sz, seq_len, _cb_plus_one) = tokens.dims3()?;
|
||||||
|
let audio_tokens = tokens.narrow(2, 0, self.config.audio_num_codebooks)?;
|
||||||
|
let text_tokens = tokens.narrow(2, self.config.audio_num_codebooks, 1)?;
|
||||||
|
let text_embeds = self.text_embeddings.forward(&text_tokens)?;
|
||||||
|
let arange = (Tensor::arange(
|
||||||
|
0u32,
|
||||||
|
self.config.audio_num_codebooks as u32,
|
||||||
|
&self.decoder.device,
|
||||||
|
)? * self.config.audio_vocab_size as f64)?;
|
||||||
|
let audio_tokens = audio_tokens.broadcast_add(&arange.reshape((1, 1, ()))?)?;
|
||||||
|
let audio_embeds = self.audio_embeddings.forward(&audio_tokens)?.reshape((
|
||||||
|
b_sz,
|
||||||
|
seq_len,
|
||||||
|
self.config.audio_num_codebooks,
|
||||||
|
(),
|
||||||
|
))?;
|
||||||
|
let embeds = Tensor::cat(&[&audio_embeds, &text_embeds], D::Minus2)?;
|
||||||
|
let embeds = embeds.broadcast_mul(
|
||||||
|
&tokens_mask
|
||||||
|
.to_dtype(self.backbone.dtype)?
|
||||||
|
.unsqueeze(D::Minus1)?,
|
||||||
|
)?;
|
||||||
|
let embeds = embeds.sum(2)?;
|
||||||
|
let h = self.backbone.forward(&embeds, input_pos)?;
|
||||||
|
let c0_logits = h.apply(&self.codebook0_head)?;
|
||||||
|
let c0_sample = lp.sample(&c0_logits.i((0, 0))?)?;
|
||||||
|
let mut all_samples = vec![c0_sample];
|
||||||
|
let c0_sample = Tensor::from_slice(&[c0_sample], (1, 1), &self.decoder.device)?;
|
||||||
|
let c0_embed = self.audio_embeddings.forward(&c0_sample)?;
|
||||||
|
let mut curr_h = Tensor::cat(&[h, c0_embed], 1)?;
|
||||||
|
|
||||||
|
self.decoder.clear_kv_cache();
|
||||||
|
let mut decoder_pos = 0;
|
||||||
|
for i in 1..self.config.audio_num_codebooks {
|
||||||
|
let proj_h = curr_h.apply(&self.projection)?;
|
||||||
|
let decoder_h = self.decoder.forward(&proj_h, decoder_pos)?;
|
||||||
|
decoder_pos += curr_h.dim(1)?;
|
||||||
|
let ci_logits = decoder_h.broadcast_matmul(&self.audio_head.get(i - 1)?)?;
|
||||||
|
let ci_sample = lp.sample(&ci_logits.i((0, 0))?)?;
|
||||||
|
all_samples.push(ci_sample);
|
||||||
|
let ci_sample = Tensor::from_slice(
|
||||||
|
&[ci_sample + (i * self.config.audio_vocab_size) as u32],
|
||||||
|
(1, 1),
|
||||||
|
&self.decoder.device,
|
||||||
|
)?;
|
||||||
|
let ci_embed = self.audio_embeddings.forward(&ci_sample)?;
|
||||||
|
curr_h = ci_embed
|
||||||
|
}
|
||||||
|
Ok(all_samples)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut mask = vec![1u8; cb];
|
||||||
|
mask.push(0);
|
||||||
|
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
|
||||||
|
|
||||||
|
frame.push(0);
|
||||||
|
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut tokens = vec![];
|
||||||
|
let mut mask = vec![];
|
||||||
|
for &v in ids.iter() {
|
||||||
|
let mut token = vec![0; cb];
|
||||||
|
token.push(v);
|
||||||
|
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
|
||||||
|
tokens.push(token);
|
||||||
|
let mut m = vec![0u8; cb];
|
||||||
|
m.push(1);
|
||||||
|
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
|
||||||
|
mask.push(m);
|
||||||
|
}
|
||||||
|
let tokens = Tensor::cat(&tokens, 1)?;
|
||||||
|
let mask = Tensor::cat(&mask, 1)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
|
}
|
@ -104,7 +104,7 @@ impl EncoderBlock {
|
|||||||
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
||||||
let cfg1 = Conv1dConfig {
|
let cfg1 = Conv1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||||
@ -196,7 +196,7 @@ impl DecoderBlock {
|
|||||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||||
let cfg = ConvTranspose1dConfig {
|
let cfg = ConvTranspose1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
||||||
@ -330,6 +330,7 @@ impl ResidualVectorQuantizer {
|
|||||||
Ok(Self { quantizers })
|
Ok(Self { quantizers })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::wrong_self_convention)]
|
||||||
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
let mut sum = None;
|
let mut sum = None;
|
||||||
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
||||||
|
@ -124,6 +124,7 @@ impl ResidualConvUnit {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
let conv1 = conv2d(
|
let conv1 = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -208,6 +209,7 @@ impl FeatureFusionBlock {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
let output_conv = conv2d(
|
let output_conv = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -258,6 +260,7 @@ impl Scratch {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let layer1_rn = conv2d_no_bias(
|
let layer1_rn = conv2d_no_bias(
|
||||||
@ -319,6 +322,7 @@ impl Scratch {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
};
|
};
|
||||||
let output_conv1 = conv2d(
|
let output_conv1 = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -425,6 +429,7 @@ impl DPTHead {
|
|||||||
stride: 2,
|
stride: 2,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
},
|
},
|
||||||
vb.pp("resize_layers").pp("3"),
|
vb.pp("resize_layers").pp("3"),
|
||||||
)?),
|
)?),
|
||||||
|
@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum HiddenAct {
|
pub enum HiddenAct {
|
||||||
Gelu,
|
Gelu,
|
||||||
Relu,
|
Relu,
|
||||||
}
|
}
|
||||||
@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum PositionEmbeddingType {
|
pub enum PositionEmbeddingType {
|
||||||
#[default]
|
#[default]
|
||||||
Absolute,
|
Absolute,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
dim: usize,
|
pub dim: usize,
|
||||||
n_layers: usize,
|
n_layers: usize,
|
||||||
n_heads: usize,
|
n_heads: usize,
|
||||||
hidden_dim: usize,
|
hidden_dim: usize,
|
||||||
activation: HiddenAct,
|
activation: HiddenAct,
|
||||||
max_position_embeddings: usize,
|
max_position_embeddings: usize,
|
||||||
initializer_range: f64,
|
initializer_range: f64,
|
||||||
pad_token_id: usize,
|
pub pad_token_id: usize,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
position_embedding_type: PositionEmbeddingType,
|
position_embedding_type: PositionEmbeddingType,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@ -345,3 +345,107 @@ impl DistilBertModel {
|
|||||||
Ok(sequence_output)
|
Ok(sequence_output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DistilBertPredictionHeadTransform {
|
||||||
|
dense: Linear,
|
||||||
|
activation: HiddenActLayer,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertPredictionHeadTransform {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
|
||||||
|
let activation = HiddenActLayer::new(config.activation);
|
||||||
|
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
activation,
|
||||||
|
layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DistilBertPredictionHeadTransform {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let hidden_states = self
|
||||||
|
.activation
|
||||||
|
.forward(&self.dense.forward(hidden_states)?)?;
|
||||||
|
self.layer_norm.forward(&hidden_states)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
|
||||||
|
pub struct DistilBertLMPredictionHead {
|
||||||
|
transform: DistilBertPredictionHeadTransform,
|
||||||
|
decoder: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertLMPredictionHead {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
|
||||||
|
|
||||||
|
// distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a seperate vocab_projector bias
|
||||||
|
let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
|
||||||
|
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
|
||||||
|
let ws = vocab_projector_weight_vb.get_with_hints(
|
||||||
|
(config.vocab_size, config.dim),
|
||||||
|
"weight",
|
||||||
|
init_ws,
|
||||||
|
)?;
|
||||||
|
let bound = 1. / (config.dim as f64).sqrt();
|
||||||
|
let init_bs = candle_nn::Init::Uniform {
|
||||||
|
lo: -bound,
|
||||||
|
up: bound,
|
||||||
|
};
|
||||||
|
|
||||||
|
let vocab_projector_bias_vb = vb.pp("vocab_projector");
|
||||||
|
let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
|
||||||
|
|
||||||
|
let decoder = Linear::from_weights(ws, Some(bs));
|
||||||
|
|
||||||
|
Ok(Self { transform, decoder })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DistilBertLMPredictionHead {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
self.decoder
|
||||||
|
.forward(&self.transform.forward(hidden_states)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
|
||||||
|
pub struct DistilBertOnlyMLMHead {
|
||||||
|
predictions: DistilBertLMPredictionHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertOnlyMLMHead {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
|
||||||
|
Ok(Self { predictions })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DistilBertOnlyMLMHead {
|
||||||
|
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
|
||||||
|
self.predictions.forward(sequence_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DistilBertForMaskedLM {
|
||||||
|
pub bert: DistilBertModel,
|
||||||
|
cls: DistilBertOnlyMLMHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertForMaskedLM {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
|
||||||
|
let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
|
||||||
|
Ok(Self { bert, cls })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let sequence_output = self.bert.forward(input_ids, attention_mask)?;
|
||||||
|
self.cls.forward(&sequence_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -141,6 +141,20 @@ pub fn conv1d_weight_norm(
|
|||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn conv1d_weight_norm_no_bias(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: candle_nn::Conv1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||||
|
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||||
|
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||||
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
|
Ok(Conv1d::new(weight, None, config))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn conv_transpose1d_weight_norm(
|
pub fn conv_transpose1d_weight_norm(
|
||||||
in_c: usize,
|
in_c: usize,
|
||||||
out_c: usize,
|
out_c: usize,
|
||||||
@ -454,6 +468,7 @@ impl EncodecConv1d {
|
|||||||
stride,
|
stride,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
|
cudnn_fwd_algo: None,
|
||||||
},
|
},
|
||||||
vb.pp("conv"),
|
vb.pp("conv"),
|
||||||
)?,
|
)?,
|
||||||
|
@ -6,8 +6,8 @@ pub fn get_noise(
|
|||||||
width: usize,
|
width: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let height = (height + 15) / 16 * 2;
|
let height = height.div_ceil(16) * 2;
|
||||||
let width = (width + 15) / 16 * 2;
|
let width = width.div_ceil(16) * 2;
|
||||||
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
|
|||||||
|
|
||||||
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||||
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
||||||
let height = (height + 15) / 16;
|
let height = height.div_ceil(16);
|
||||||
let width = (width + 15) / 16;
|
let width = width.div_ceil(16);
|
||||||
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
||||||
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
||||||
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
||||||
|
@ -27,7 +27,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_inner(&self) -> usize {
|
fn d_inner(&self) -> usize {
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user