Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-17 19:18:50 +00:00

Compare commits: 0.9.0-alph...fix-1.86
9 commits

SHA1
5341bf4cd5
8977c31b6d
3be12b8b50
825119ac4b
e319cd78d9
3fb67e0c2c
d72c44705c
2203f0e3c9
01e895c1aa
.github/workflows/book-cd.yml (vendored, new file, 40 lines)
@@ -0,0 +1,40 @@
+name: Deploy Rust book
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write # To push a branch
+      pull-requests: write # To create a PR from that branch
+    steps:
+      - uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - name: Install latest mdbook
+        run: |
+          tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
+          url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
+          mkdir mdbook
+          curl -sSL $url | tar -xz --directory=./mdbook
+          echo `pwd`/mdbook >> $GITHUB_PATH
+      - name: Deploy GitHub Pages
+        run: |
+          # This assumes your book is in the root of your repository.
+          # Just add a `cd` here if you need to change to another directory.
+          cd candle-book
+          mdbook build
+          git worktree add gh-pages
+          git config user.name "Deploy from CI"
+          git config user.email ""
+          cd gh-pages
+          # Delete the ref to avoid keeping history.
+          git update-ref -d refs/heads/gh-pages
+          rm -rf *
+          mv ../book/* .
+          git add .
+          git commit -m "Deploy $GITHUB_SHA to gh-pages"
+          git push --force --set-upstream origin gh-pages
.github/workflows/book.yml (vendored, new file, 29 lines)
@@ -0,0 +1,29 @@
+name: CI
+on:
+  pull_request:
+
+jobs:
+  test:
+    name: Test candle-book
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write # To push a branch
+      pull-requests: write # To create a PR from that branch
+    steps:
+      - uses: actions/checkout@master
+      - name: Install Rust
+        run: |
+          rustup set profile minimal
+          rustup toolchain install stable
+          rustup default stable
+      - name: Install latest mdbook
+        run: |
+          tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
+          url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
+          mkdir bin
+          curl -sSL $url | tar -xz --directory=bin
+          echo "$(pwd)/bin" >> $GITHUB_PATH
+      - name: Run tests
+        run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/
Cargo.toml (28 changed lines)
@@ -3,6 +3,7 @@ members = [
     "candle-core",
     "candle-datasets",
     "candle-examples",
+    "candle-book",
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
@@ -11,7 +12,6 @@ members = [
     "tensor-tools",
 ]
 exclude = [
-    "candle-book",
     "candle-flash-attn",
     "candle-kernels",
     "candle-metal-kernels",
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.9.0-alpha.3"
+version = "0.9.0-alpha.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.3" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.3" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.3" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.3" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
-candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.3" }
+candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.3" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.3" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
@@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
-ug = "0.3.1"
+ug = "0.2.0"
-ug-cuda = "0.3.1"
+ug-cuda = "0.2.0"
-ug-metal = "0.3.1"
+ug-metal = "0.2.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "1.1.1", default-features = false }
 metal = { version = "0.27.0", features = ["mps"]}
@@ -56,7 +56,3 @@ harness = false
 [[example]]
 name = "metal_basics"
 required-features = ["metal"]
-
-[[example]]
-name = "cuda_basics"
-required-features = ["cuda"]
@@ -6,18 +6,28 @@ extern crate intel_mkl_src;
 
 use anyhow::Result;
 use candle_core::{Device, Tensor};
-// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
 fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
-    let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
-    let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
-    let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
+    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
+        .to_dtype(candle_core::DType::BF16)?;
+    candle_core::cuda::set_gemm_reduced_precision_f32(false);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(false);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("fp32: {:?}", start_time.elapsed());
+    drop(_x1);
+    candle_core::cuda::set_gemm_reduced_precision_f32(true);
+    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
+    let _x1 = x.matmul(&x)?;
+    drop(_x1);
+    let start_time = std::time::Instant::now();
+    let _x1 = x.matmul(&x)?;
+    device.synchronize()?;
+    println!("tf32: {:?}", start_time.elapsed());
     drop(_x1);
-    for _ in 0..20 {
-        let start_time = std::time::Instant::now();
-        let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
-        device.synchronize()?;
-        println!("conv1d: {:?}", start_time.elapsed());
-    }
     Ok(())
 }
@@ -14,7 +14,6 @@ pub struct ParamsConv1D {
     pub(crate) padding: usize,
     pub(crate) stride: usize,
     pub(crate) dilation: usize,
-    pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }
 
 impl ParamsConv1D {
@@ -55,7 +54,7 @@ impl ParamsConvTranspose1D {
     }
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum CudnnFwdAlgo {
     ImplicitGemm,
     ImplicitPrecompGemm,
@@ -152,19 +151,6 @@ impl Tensor {
         stride: usize,
         dilation: usize,
         groups: usize,
-    ) -> Result<Self> {
-        self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
-    }
-
-    /// Applies a 1D convolution over the input tensor.
-    pub fn conv1d_with_algo(
-        &self,
-        kernel: &Self,
-        padding: usize,
-        stride: usize,
-        dilation: usize,
-        groups: usize,
-        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
     ) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.dims3()?;
         let (b_size, c_in, l_in) = self.dims3()?;
@@ -188,7 +174,6 @@ impl Tensor {
             padding,
             stride,
             dilation,
-            cudnn_fwd_algo,
         };
         if groups == 1 {
             self.conv1d_single_group(kernel, &params)
@@ -293,18 +278,6 @@ impl Tensor {
         stride: usize,
         dilation: usize,
         groups: usize,
-    ) -> Result<Self> {
-        self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
-    }
-
-    pub fn conv2d_with_algo(
-        &self,
-        kernel: &Self,
-        padding: usize,
-        stride: usize,
-        dilation: usize,
-        groups: usize,
-        cudnn_fwd_algo: Option<CudnnFwdAlgo>,
     ) -> Result<Self> {
         let (b_size, c_in, i_h, i_w) = self.dims4()?;
         let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@@ -324,7 +297,7 @@ impl Tensor {
             padding,
             stride,
             dilation,
-            cudnn_fwd_algo,
+            cudnn_fwd_algo: None,
         };
         if groups == 1 {
             self.conv2d_single_group(kernel, &params)
@@ -1289,15 +1289,6 @@ impl Map2 for MatMul {
         } else {
             Parallelism::None
         };
-        let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
-            // a_skip and c_skip should be updated but step is always 0 so
-            // it wouldn't matter.
-            (1, b * m, n, k)
-        } else if a_skip == 0 && b_skip == n * k {
-            (1, m, b * n, k)
-        } else {
-            (b, m, n, k)
-        };
         for step in 0..b {
            let lhs_p = &lhs[step * a_skip..];
            let rhs_p = &rhs[step * b_skip..];
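The block removed above collapsed a batched matmul into one larger GEMM when one operand has a batch stride of zero (it is broadcast across the batch) and the other operand's batches are laid out contiguously. A minimal, self-contained sketch of why that is valid for the broadcast-rhs case; everything below is illustrative and independent of candle's internals:

// Illustrative only: a batched matmul with a single rhs shared by all batches
// (b_skip == 0, contiguous lhs batches) equals one (b*m, k) x (k, n) product.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                c[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    c
}

fn main() {
    let (b, m, k, n) = (3usize, 2usize, 4usize, 5usize);
    // lhs: b contiguous batches of shape (m, k); rhs: one (k, n) shared by all batches.
    let lhs: Vec<f32> = (0..b * m * k).map(|i| i as f32).collect();
    let rhs: Vec<f32> = (0..k * n).map(|i| (i as f32) * 0.5).collect();

    // Per-batch products, concatenated.
    let mut per_batch = Vec::new();
    for step in 0..b {
        per_batch.extend(matmul(&lhs[step * m * k..(step + 1) * m * k], &rhs, m, k, n));
    }

    // Single collapsed product: (b*m, k) x (k, n).
    let collapsed = matmul(&lhs, &rhs, b * m, k, n);

    assert_eq!(per_batch, collapsed);
    println!("batched and collapsed results match");
}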
@@ -122,104 +122,3 @@ pub(crate) fn launch_conv2d<
     }
     Ok(())
 }
-
-pub(crate) fn launch_conv1d<
-    T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
-    Y: cudarc::cudnn::CudnnDataType,
->(
-    src: &CudaView<T>,
-    src_l: &crate::Layout,
-    filter: &CudaView<T>,
-    dst: &mut CudaSlice<T>,
-    params: &crate::conv::ParamsConv1D,
-    dev: &crate::cuda_backend::CudaDevice,
-) -> crate::Result<()> {
-    use crate::conv::CudnnFwdAlgo as CandleAlgo;
-    use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
-
-    let device_id = dev.id();
-    let cudnn = CUDNN.with(|cudnn| {
-        if let Some(cudnn) = cudnn.borrow().get(&device_id) {
-            return Ok(cudnn.clone());
-        }
-        let c = Cudnn::new(dev.cuda_stream());
-        if let Ok(c) = &c {
-            cudnn.borrow_mut().insert(device_id, c.clone());
-        }
-        c
-    })?;
-    let conv = cudnn.create_conv2d::<Y>(
-        /* pad */ [params.padding as i32, 0],
-        /* stride */ [params.stride as i32, 1],
-        /* dilation */ [params.dilation as i32, 1],
-        cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
-    )?;
-    // https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
-    // > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
-    // > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
-    // > recommended that the user create a 4D tensor, and set the size along unused dimensions
-    // > to 1.
-    let x_shape = [
-        params.b_size as i32,
-        params.c_in as i32,
-        params.l_in as i32,
-        1,
-    ];
-    // Note that `src` already starts at the proper offset.
-    let x = if src_l.is_contiguous() {
-        cudnn.create_4d_tensor::<T>(
-            cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
-            x_shape,
-        )?
-    } else {
-        let s = src_l.stride();
-        cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
-    };
-    let w = cudnn.create_4d_filter::<T>(
-        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
-        [
-            params.c_out as i32,
-            params.c_in as i32,
-            params.k_size as i32,
-            1,
-        ],
-    )?;
-    let l_out = params.l_out() as i32;
-    let y = cudnn.create_4d_tensor::<T>(
-        cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
-        [params.b_size as i32, params.c_out as i32, l_out, 1],
-    )?;
-    let conv1d = ConvForward {
-        conv: &conv,
-        x: &x,
-        w: &w,
-        y: &y,
-    };
-    let alg = match params.cudnn_fwd_algo {
-        None => conv1d.pick_algorithm()?,
-        Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
-        Some(CandleAlgo::ImplicitPrecompGemm) => {
-            A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
-        }
-        Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
-        Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
-        Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
-        Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
-        Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
-        Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
-        Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
-    };
-    let workspace_size = conv1d.get_workspace_size(alg)?;
-    let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
-    unsafe {
-        conv1d.launch::<CudaSlice<u8>, _, _, _>(
-            alg,
-            Some(&mut workspace),
-            (T::one(), T::zero()),
-            src,
-            filter,
-            dst,
-        )?;
-    }
-    Ok(())
-}
@@ -46,61 +46,11 @@ impl std::fmt::Debug for CudaDevice {
     }
 }
 
-impl CudaDevice {
-    #[allow(clippy::missing_safety_doc)]
-    pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
-        &self,
-        len: usize,
-    ) -> Result<cudarc::driver::CudaSlice<T>> {
-        self.stream.alloc::<T>(len).w()
-    }
-
-    pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
-        &self,
-        len: usize,
-    ) -> Result<cudarc::driver::CudaSlice<T>> {
-        self.stream.alloc_zeros::<T>(len).w()
-    }
-
-    pub fn memcpy_htod<
-        T: cudarc::driver::DeviceRepr,
-        Src: cudarc::driver::HostSlice<T> + ?Sized,
-        Dst: cudarc::driver::DevicePtrMut<T>,
-    >(
-        &self,
-        src: &Src,
-        dst: &mut Dst,
-    ) -> Result<()> {
-        self.stream.memcpy_htod(src, dst).w()
-    }
-
-    pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
-        &self,
-        src: &Src,
-    ) -> Result<Vec<T>> {
-        self.stream.memcpy_dtov(src).w()
-    }
-
-    pub fn memcpy_dtod<
-        T,
-        Src: cudarc::driver::DevicePtr<T>,
-        Dst: cudarc::driver::DevicePtrMut<T>,
-    >(
-        &self,
-        src: &Src,
-        dst: &mut Dst,
-    ) -> Result<()> {
-        self.stream.memcpy_dtod(src, dst).w()
-    }
-
-    pub fn memcpy_stod<
-        T: cudarc::driver::DeviceRepr,
-        Src: cudarc::driver::HostSlice<T> + ?Sized,
-    >(
-        &self,
-        src: &Src,
-    ) -> Result<cudarc::driver::CudaSlice<T>> {
-        self.stream.memcpy_stod(src).w()
+impl std::ops::Deref for CudaDevice {
+    type Target = Arc<cudarc::driver::CudaStream>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.stream
     }
 }
 
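Most of the remaining hunks in this comparison follow from the change above and only toggle between two call-site styles: with the explicit wrapper methods, conversion from the CUDA driver error into candle's own error type happens inside the device, so call sites use `?`; with the Deref to the cudarc stream, call sites reach the stream's methods directly and do the conversion themselves with `.w()`. A minimal sketch of that pattern using mock types; none of these names are candle's or cudarc's actual API:

// Mock types to illustrate the two call-site styles seen throughout this diff.
#[derive(Debug)]
struct DriverError(String);
#[derive(Debug)]
struct Error(String);

trait WrapErr<T> {
    fn w(self) -> Result<T, Error>;
}
impl<T> WrapErr<T> for Result<T, DriverError> {
    // `.w()` converts the low-level driver error into the crate-level error.
    fn w(self) -> Result<T, Error> {
        self.map_err(|e| Error(format!("driver: {}", e.0)))
    }
}

struct Stream;
impl Stream {
    fn alloc_zeros(&self, len: usize) -> Result<Vec<f32>, DriverError> {
        Ok(vec![0.0; len])
    }
}

struct Device {
    stream: Stream,
}

impl Device {
    // Wrapper-method style: the error conversion lives here, call sites just use `?`.
    fn alloc_zeros(&self, len: usize) -> Result<Vec<f32>, Error> {
        self.stream.alloc_zeros(len).w()
    }
}

// A Deref impl would expose the stream directly; here plain field access stands in
// for it, so the call site has to convert the error itself with `.w()?`.
fn main() -> Result<(), Error> {
    let dev = Device { stream: Stream };
    let a = dev.alloc_zeros(4)?; // wrapper style
    let b = dev.stream.alloc_zeros(4).w()?; // "deref"-like style
    assert_eq!(a, b);
    Ok(())
}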
@@ -176,7 +126,7 @@ impl CudaDevice {
             DType::U8 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u8>(elem_count)? };
+                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
                 let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
@@ -188,7 +138,7 @@ impl CudaDevice {
             DType::U32 => {
-                let data = unsafe { self.alloc::<u32>(elem_count)? };
+                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
@@ -200,7 +150,7 @@ impl CudaDevice {
             DType::I64 => {
-                let data = unsafe { self.alloc::<i64>(elem_count)? };
+                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
@@ -212,7 +162,7 @@ impl CudaDevice {
             DType::BF16 => {
-                let data = unsafe { self.alloc::<bf16>(elem_count)? };
+                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
@@ -224,7 +174,7 @@ impl CudaDevice {
             DType::F16 => {
-                let data = unsafe { self.alloc::<f16>(elem_count)? };
+                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
@@ -236,7 +186,7 @@ impl CudaDevice {
             DType::F32 => {
-                let data = unsafe { self.alloc::<f32>(elem_count)? };
+                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
@@ -248,7 +198,7 @@ impl CudaDevice {
             DType::F64 => {
-                let data = unsafe { self.alloc::<f64>(elem_count) }?;
+                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
@@ -375,31 +325,31 @@ impl BackendDevice for CudaDevice {
-                let data = self.alloc_zeros::<u8>(elem_count)?;
+                let data = self.alloc_zeros::<u8>(elem_count).w()?;
-                let data = self.alloc_zeros::<u32>(elem_count)?;
+                let data = self.alloc_zeros::<u32>(elem_count).w()?;
-                let data = self.alloc_zeros::<i64>(elem_count)?;
+                let data = self.alloc_zeros::<i64>(elem_count).w()?;
-                let data = self.alloc_zeros::<bf16>(elem_count)?;
+                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
-                let data = self.alloc_zeros::<f16>(elem_count)?;
+                let data = self.alloc_zeros::<f16>(elem_count).w()?;
-                let data = self.alloc_zeros::<f32>(elem_count)?;
+                let data = self.alloc_zeros::<f32>(elem_count).w()?;
-                let data = self.alloc_zeros::<f64>(elem_count)?;
+                let data = self.alloc_zeros::<f64>(elem_count).w()?;
@@ -423,12 +373,12 @@ impl BackendDevice for CudaDevice {
-                let mut data = unsafe { self.alloc::<f32>(elem_count)? };
+                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
-                let mut data = unsafe { self.alloc::<f64>(elem_count)? };
+                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
@@ -467,7 +417,7 @@ impl BackendDevice for CudaDevice {
-                let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
+                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
@@ -475,7 +425,7 @@ impl BackendDevice for CudaDevice {
-                let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
+                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
@@ -494,31 +444,31 @@ impl BackendDevice for CudaDevice {
-                let data = self.alloc::<u8>(elem_count)?;
+                let data = self.alloc::<u8>(elem_count).w()?;
-                let data = self.alloc::<u32>(elem_count)?;
+                let data = self.alloc::<u32>(elem_count).w()?;
-                let data = self.alloc::<i64>(elem_count)?;
+                let data = self.alloc::<i64>(elem_count).w()?;
-                let data = self.alloc::<bf16>(elem_count)?;
+                let data = self.alloc::<bf16>(elem_count).w()?;
-                let data = self.alloc::<f16>(elem_count)?;
+                let data = self.alloc::<f16>(elem_count).w()?;
-                let data = self.alloc::<f32>(elem_count)?;
+                let data = self.alloc::<f32>(elem_count).w()?;
-                let data = self.alloc::<f64>(elem_count)?;
+                let data = self.alloc::<f64>(elem_count).w()?;
@@ -531,31 +481,31 @@ impl BackendDevice for CudaDevice {
 fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
(the same one-line change in each of the U8, U32, I64, BF16, F16, F32 and F64 arms)
-                let data = self.memcpy_stod(storage)?;
+                let data = self.memcpy_stod(storage).w()?;
@@ -568,31 +518,31 @@ impl BackendDevice for CudaDevice {
 fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
(the same one-line change in each of the U8, U32, I64, BF16, F16, F32 and F64 arms)
-                let data = self.memcpy_stod(storage)?;
+                let data = self.memcpy_stod(storage).w()?;
@@ -605,31 +555,31 @@ impl BackendDevice for CudaDevice {
 fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
(the same one-line change in each of the U8, U32, I64, BF16, F16, F32 and F64 arms)
-                let data = self.memcpy_stod(&storage)?;
+                let data = self.memcpy_stod(&storage).w()?;
@@ -39,7 +39,7 @@ impl SlicePtrOrNull<usize> {
         let ds = if l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {
-            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat())?)
+            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat()).w()?)
         };
         Ok(ds)
     }
@@ -89,7 +89,7 @@ impl Map1 for Affine {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), &kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el)? };
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, dims.len());
@@ -120,7 +120,7 @@ impl Map1 for Elu {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el)? };
+        let out = unsafe { dev.alloc::<T>(el) }.w()?;
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, dims.len());
@@ -134,7 +134,6 @@ impl Map1 for Elu {
     }
 }
 
-#[allow(unused)]
 struct Im2Col1D {
     l_k: usize,
     stride: usize,
@@ -143,7 +142,6 @@ struct Im2Col1D {
 }
 
 impl Im2Col1D {
-    #[allow(unused)]
     fn l_out(&self, l: usize) -> usize {
         (l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
     }
@@ -159,15 +157,15 @@ impl Map1 for Im2Col1D {
         let shape = layout.shape();
         let dims = shape.dims();
         let l_out = self.l_out(dims[2]);
-        let threads = dims[0] * l_out * dims[1];
-        let cfg = LaunchConfig::for_num_elems(threads as u32);
-        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
+        let dst_el = dims[0] * l_out * dims[1] * self.l_k;
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
+        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
         let mut builder = func.builder();
-        barg!(builder, threads);
+        barg!(builder, dst_el);
         barg!(builder, l_out);
         barg!(builder, self.l_k);
         barg!(builder, self.stride);
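The `l_out` expression used by Im2Col1D above is the standard 1D convolution output-length formula. A small standalone check (illustrative only; the numeric case mirrors the conv1d setup from the old cuda_basics example, length 1924, kernel 8, stride 4):

// Standalone check of the output-length formula used by Im2Col1D.
fn l_out(l: usize, padding: usize, dilation: usize, l_k: usize, stride: usize) -> usize {
    (l + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1
}

fn main() {
    // Length 1924, no padding, dilation 1, kernel 8, stride 4 -> 480 output positions.
    assert_eq!(l_out(1924, 0, 1, 8, 4), 480);
    // A size-1 kernel at stride 1 keeps the length unchanged.
    assert_eq!(l_out(10, 0, 1, 1, 1), 10);
    // "Same" padding for a size-3 kernel at stride 1.
    assert_eq!(l_out(10, 1, 1, 3, 1), 10);
    println!("l_out checks passed");
}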
@ -212,11 +210,11 @@ impl Map1 for Im2Col {
|
|||||||
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
|
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
|
||||||
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
|
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
|
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(dst_el)? };
|
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, dst_el);
|
barg!(builder, dst_el);
|
||||||
barg!(builder, h_out);
|
barg!(builder, h_out);
|
||||||
@ -251,7 +249,7 @@ impl Map1 for Powf {
|
|||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el)? };
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -304,7 +302,9 @@ impl Map1Any for FastReduce<'_> {
|
|||||||
block_dim: (block_dim as u32, 1, 1),
|
block_dim: (block_dim as u32, 1, 1),
|
||||||
shared_mem_bytes: 0,
|
shared_mem_bytes: 0,
|
||||||
};
|
};
|
||||||
let ds = dev.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())?;
|
let ds = dev
|
||||||
|
.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())
|
||||||
|
.w()?;
|
||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let (name, check_empty, return_index) = match self.1 {
|
let (name, check_empty, return_index) = match self.1 {
|
||||||
ReduceOp::Sum => ("fast_sum", false, false),
|
ReduceOp::Sum => ("fast_sum", false, false),
|
||||||
@ -319,7 +319,7 @@ impl Map1Any for FastReduce<'_> {
|
|||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
|
||||||
if return_index {
|
if return_index {
|
||||||
// SAFETY: filled in by the follow up kernel.
|
// SAFETY: filled in by the follow up kernel.
|
||||||
let out = unsafe { dev.alloc::<u32>(dst_el)? };
|
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, src_el);
|
barg!(builder, src_el);
|
||||||
barg!(builder, el_to_sum_per_block);
|
barg!(builder, el_to_sum_per_block);
|
||||||
@ -332,7 +332,7 @@ impl Map1Any for FastReduce<'_> {
|
|||||||
Ok(S::U32(out))
|
Ok(S::U32(out))
|
||||||
} else {
|
} else {
|
||||||
// SAFETY: filled in by the follow up kernel.
|
// SAFETY: filled in by the follow up kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, src_el);
|
barg!(builder, src_el);
|
||||||
barg!(builder, el_to_sum_per_block);
|
barg!(builder, el_to_sum_per_block);
|
||||||
@ -362,7 +362,7 @@ impl<U: UnaryOpT> Map1 for U {
|
|||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let mut out = unsafe { dev.alloc::<T>(el_count)? };
|
let mut out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el_count);
|
barg!(builder, el_count);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -403,7 +403,7 @@ impl Map1 for IndexSelect<'_> {
|
|||||||
};
|
};
|
||||||
let ids_shape = ids_l.shape();
|
let ids_shape = ids_l.shape();
|
||||||
let ids_dims = ids_shape.dims();
|
let ids_dims = ids_shape.dims();
|
||||||
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat())?;
|
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat()).w()?;
|
||||||
let src = match src_l.contiguous_offsets() {
|
let src = match src_l.contiguous_offsets() {
|
||||||
Some((o1, o2)) => src.slice(o1..o2),
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
|
||||||
@ -416,7 +416,7 @@ impl Map1 for IndexSelect<'_> {
|
|||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, dst_el);
|
barg!(builder, dst_el);
|
||||||
barg!(builder, ids_dims.len());
|
barg!(builder, ids_dims.len());
|
||||||
@ -471,7 +471,7 @@ impl Map1 for Gather<'_> {
|
|||||||
let ids_dim_sz = ids_l.dims()[dim];
|
let ids_dim_sz = ids_l.dims()[dim];
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el)? };
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, ids);
|
barg!(builder, ids);
|
||||||
@ -608,7 +608,7 @@ impl Map2 for Conv1D<'_> {
|
|||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let ds = if dims.len() == 3 {
|
let ds = if dims.len() == 3 {
|
||||||
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
|
||||||
} else if dims.len() == 2 {
|
} else if dims.len() == 2 {
|
||||||
@ -616,7 +616,7 @@ impl Map2 for Conv1D<'_> {
|
|||||||
} else {
|
} else {
|
||||||
crate::bail!("unexpected input shape for conv1d {dims:?}")
|
crate::bail!("unexpected input shape for conv1d {dims:?}")
|
||||||
};
|
};
|
||||||
let ds = dev.memcpy_stod(&ds)?;
|
let ds = dev.memcpy_stod(&ds).w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
|
barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
|
||||||
builder.arg(&ds);
|
builder.arg(&ds);
|
||||||
@ -651,7 +651,7 @@ impl Map2 for Conv2D<'_> {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
|
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
|
||||||
let ds = if dims.len() == 4 {
|
let ds = if dims.len() == 4 {
|
||||||
@ -659,7 +659,7 @@ impl Map2 for Conv2D<'_> {
|
|||||||
} else {
|
} else {
|
||||||
crate::bail!("unexpected input shape for conv2d {dims:?}")
|
crate::bail!("unexpected input shape for conv2d {dims:?}")
|
||||||
};
|
};
|
||||||
let ds = dev.memcpy_stod(&ds)?;
|
let ds = dev.memcpy_stod(&ds).w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
|
barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
|
||||||
builder.arg(&ds);
|
builder.arg(&ds);
|
||||||
@ -687,7 +687,7 @@ impl Map1 for Col2Im1D {
|
|||||||
let stride = self.stride;
|
let stride = self.stride;
|
||||||
let l_out = (l_in - 1) * stride + k_size;
|
let l_out = (l_in - 1) * stride + k_size;
|
||||||
let dst_el = b_size * c_out * l_out;
|
let dst_el = b_size * c_out * l_out;
|
||||||
let mut im = unsafe { dev.alloc::<T>(dst_el)? };
|
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
|
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
|
||||||
@ -722,7 +722,7 @@ impl Map2 for ConvTranspose1D<'_> {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
|
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
|
||||||
let ds = if dims.len() == 3 {
|
let ds = if dims.len() == 3 {
|
||||||
@ -730,7 +730,7 @@ impl Map2 for ConvTranspose1D<'_> {
|
|||||||
} else {
|
} else {
|
||||||
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
|
||||||
};
|
};
|
||||||
let ds = dev.memcpy_stod(&ds)?;
|
let ds = dev.memcpy_stod(&ds).w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, l_out);
|
barg!(builder, l_out);
|
||||||
@ -770,7 +770,7 @@ impl Map2 for ConvTranspose2D<'_> {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
|
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
|
||||||
let ds = if dims.len() == 4 {
|
let ds = if dims.len() == 4 {
|
||||||
@ -778,7 +778,7 @@ impl Map2 for ConvTranspose2D<'_> {
|
|||||||
} else {
|
} else {
|
||||||
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
|
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
|
||||||
};
|
};
|
||||||
let ds = dev.memcpy_stod(&ds)?;
|
let ds = dev.memcpy_stod(&ds).w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, out_w);
|
barg!(builder, out_w);
|
||||||
@ -837,8 +837,8 @@ impl Map1 for Pool2D {
|
|||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let ds = dev.memcpy_stod(&ds)?;
|
let ds = dev.memcpy_stod(&ds).w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, self.w_k);
|
barg!(builder, self.w_k);
|
||||||
@ -876,8 +876,8 @@ impl Map1 for UpsampleNearest2D {
|
|||||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(dst_el)? };
|
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||||
let ds = dev.memcpy_stod(&ds)?;
|
let ds = dev.memcpy_stod(&ds).w()?;
|
||||||
let scale_w = dims[2] as f64 / out_w as f64;
|
let scale_w = dims[2] as f64 / out_w as f64;
|
||||||
let scale_h = dims[3] as f64 / out_h as f64;
|
let scale_h = dims[3] as f64 / out_h as f64;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -930,12 +930,13 @@ impl Map2 for WhereCond<'_> {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
let ds = dev
|
let ds = dev
|
||||||
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
|
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
|
||||||
|
.w()?;
|
||||||
let t = &t.slice(layout_t.start_offset()..);
|
let t = &t.slice(layout_t.start_offset()..);
|
||||||
let f = &f.slice(layout_f.start_offset()..);
|
let f = &f.slice(layout_f.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el)? };
|
let out = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -966,13 +967,16 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
|
|||||||
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||||
SlicePtrOrNull::Null
|
SlicePtrOrNull::Null
|
||||||
} else {
|
} else {
|
||||||
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
|
SlicePtrOrNull::Ptr(
|
||||||
|
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||||
|
.w()?,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(elem_count)? };
|
let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, elem_count);
|
barg!(builder, elem_count);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
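Both the binary-op hunk above and the `Cmp` hunk below keep the same contiguity shortcut: when both operands are contiguous, no dims/strides buffer is copied to the device and the kernel receives a null pointer (`SlicePtrOrNull::Null`) so it can index linearly. A small host-side sketch of that decision, with `Option` standing in for the pointer-or-null wrapper:

```rust
// Only non-contiguous layouts pay for a dims/strides upload.
fn dims_and_strides(
    dims: &[usize],
    lhs_stride: &[usize],
    rhs_stride: &[usize],
    lhs_contiguous: bool,
    rhs_contiguous: bool,
) -> Option<Vec<usize>> {
    if lhs_contiguous && rhs_contiguous {
        None // kernel gets a null pointer and uses flat indexing
    } else {
        // Shape and both stride vectors concatenated into one device buffer.
        Some([dims, lhs_stride, rhs_stride].concat())
    }
}

fn main() {
    let info = dims_and_strides(&[2, 3], &[3, 1], &[1, 2], true, false);
    assert_eq!(info, Some(vec![2, 3, 3, 1, 1, 2]));
    println!("{:?}", info);
}
```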
@ -1003,7 +1007,10 @@ impl Map2Any for Cmp {
|
|||||||
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||||
SlicePtrOrNull::Null
|
SlicePtrOrNull::Null
|
||||||
} else {
|
} else {
|
||||||
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
|
SlicePtrOrNull::Ptr(
|
||||||
|
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
|
||||||
|
.w()?,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
@ -1017,7 +1024,7 @@ impl Map2Any for Cmp {
|
|||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<u8>(elem_count)? };
|
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, elem_count);
|
barg!(builder, elem_count);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1201,6 +1208,7 @@ fn gemm_config<T>(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(StridedBatchedConfig {
|
Ok(StridedBatchedConfig {
|
||||||
batch_size: b as i32,
|
batch_size: b as i32,
|
||||||
gemm,
|
gemm,
|
||||||
@ -1261,7 +1269,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
|
let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
DType::U8 => {
|
DType::U8 => {
|
||||||
let out = unsafe { dev.alloc::<u8>(el)? };
|
let out = unsafe { dev.alloc::<u8>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1272,7 +1280,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
CudaStorageSlice::U8(out)
|
CudaStorageSlice::U8(out)
|
||||||
}
|
}
|
||||||
DType::U32 => {
|
DType::U32 => {
|
||||||
let out = unsafe { dev.alloc::<u32>(el)? };
|
let out = unsafe { dev.alloc::<u32>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1283,7 +1291,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
CudaStorageSlice::U32(out)
|
CudaStorageSlice::U32(out)
|
||||||
}
|
}
|
||||||
DType::I64 => {
|
DType::I64 => {
|
||||||
let out = unsafe { dev.alloc::<i64>(el)? };
|
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1294,7 +1302,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
CudaStorageSlice::I64(out)
|
CudaStorageSlice::I64(out)
|
||||||
}
|
}
|
||||||
DType::BF16 => {
|
DType::BF16 => {
|
||||||
let out = unsafe { dev.alloc::<bf16>(el)? };
|
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1305,7 +1313,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
CudaStorageSlice::BF16(out)
|
CudaStorageSlice::BF16(out)
|
||||||
}
|
}
|
||||||
DType::F16 => {
|
DType::F16 => {
|
||||||
let out = unsafe { dev.alloc::<f16>(el)? };
|
let out = unsafe { dev.alloc::<f16>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1316,7 +1324,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
CudaStorageSlice::F16(out)
|
CudaStorageSlice::F16(out)
|
||||||
}
|
}
|
||||||
DType::F32 => {
|
DType::F32 => {
|
||||||
let out = unsafe { dev.alloc::<f32>(el)? };
|
let out = unsafe { dev.alloc::<f32>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1327,7 +1335,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
CudaStorageSlice::F32(out)
|
CudaStorageSlice::F32(out)
|
||||||
}
|
}
|
||||||
DType::F64 => {
|
DType::F64 => {
|
||||||
let out = unsafe { dev.alloc::<f64>(el)? };
|
let out = unsafe { dev.alloc::<f64>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, el);
|
barg!(builder, el);
|
||||||
barg!(builder, dims.len());
|
barg!(builder, dims.len());
|
||||||
@ -1437,7 +1445,6 @@ impl BackendStorage for CudaStorage {
|
|||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(not(feature = "cudnn"))]
|
|
||||||
fn conv1d(
|
fn conv1d(
|
||||||
&self,
|
&self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
@ -1466,11 +1473,12 @@ impl BackendStorage for CudaStorage {
|
|||||||
let n = params.c_out;
|
let n = params.c_out;
|
||||||
let k = params.k_size * params.c_in;
|
let k = params.k_size * params.c_in;
|
||||||
let m = l_out;
|
let m = l_out;
|
||||||
let col_l = Layout::contiguous((b * m, k));
|
let col_l = Layout::contiguous((b, m, k));
|
||||||
let res = if kernel_l.is_contiguous() {
|
let res = if kernel_l.is_contiguous() {
|
||||||
let kernel_l =
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
.transpose(1, 2)?
|
||||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = unsafe {
|
let mut kernel_c = unsafe {
|
||||||
@ -1478,9 +1486,10 @@ impl BackendStorage for CudaStorage {
|
|||||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||||
};
|
};
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l =
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
.transpose(1, 2)?
|
||||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
};
|
};
|
||||||
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
|
||||||
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
|
||||||
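The two sides of the conv1d hunk feed the same im2col output to `matmul` with different bookkeeping: one flattens the batch into a single `(1, b*m, n, k)` product against a 2-D `(n, k)` kernel view, the other keeps a `(b, m, k)` column layout and broadcasts the kernel as `(b, k, n)`. Either way the underlying contraction is plain im2col plus GEMM; the sketch below is a CPU illustration with stride 1 and no padding, not the CUDA path:

```rust
// im2col + matmul for a 1-D convolution, mirroring the shapes in the diff:
// columns are (b*m, k) with m = l_out and k = k_size * c_in, the kernel is
// viewed as (n, k) with n = c_out, and the product is (b*m, n).
fn conv1d_im2col(
    input: &[f32],  // (b, c_in, l)
    kernel: &[f32], // (c_out, c_in, k_size)
    b: usize, c_in: usize, l: usize, c_out: usize, k_size: usize,
) -> Vec<f32> {
    let l_out = l - k_size + 1;
    let (m, k, n) = (l_out, k_size * c_in, c_out);
    // im2col: gather each sliding window into one row of a (b*m, k) matrix.
    let mut col = vec![0f32; b * m * k];
    for bi in 0..b {
        for mi in 0..m {
            for ci in 0..c_in {
                for ki in 0..k_size {
                    col[(bi * m + mi) * k + ci * k_size + ki] =
                        input[bi * c_in * l + ci * l + mi + ki];
                }
            }
        }
    }
    // matmul: (b*m, k) x (k, n) -> (b*m, n); candle later reshapes and
    // transposes this to (b, c_out, l_out).
    let mut out = vec![0f32; b * m * n];
    for row in 0..b * m {
        for ni in 0..n {
            let mut acc = 0f32;
            for kk in 0..k {
                acc += col[row * k + kk] * kernel[ni * k + kk];
            }
            out[row * n + ni] = acc;
        }
    }
    out
}

fn main() {
    let input = vec![1., 2., 3., 4.]; // b=1, c_in=1, l=4
    let kernel = vec![1., 1.];        // c_out=1, k_size=2
    let out = conv1d_im2col(&input, &kernel, 1, 1, 4, 1, 2);
    assert_eq!(out, vec![3., 5., 7.]);
    println!("{out:?}");
}
```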
@ -1488,72 +1497,6 @@ impl BackendStorage for CudaStorage {
|
|||||||
Ok(res_t)
|
Ok(res_t)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cudnn")]
|
|
||||||
fn conv1d(
|
|
||||||
&self,
|
|
||||||
inp_l: &Layout,
|
|
||||||
kernel: &Self,
|
|
||||||
kernel_l: &Layout,
|
|
||||||
params: &crate::conv::ParamsConv1D,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let device = self.device().clone();
|
|
||||||
if !kernel_l.is_contiguous() {
|
|
||||||
let slice = Conv1D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
|
|
||||||
return Ok(Self { slice, device });
|
|
||||||
}
|
|
||||||
let l_out = params.l_out();
|
|
||||||
let dst_el = params.c_out * l_out * params.b_size;
|
|
||||||
let slice = match (&self.slice, &kernel.slice) {
|
|
||||||
(S::U8(inp), S::U8(k)) => {
|
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
|
||||||
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
|
|
||||||
crate::cudnn::launch_conv1d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
|
|
||||||
.map_err(crate::Error::wrap)?;
|
|
||||||
S::U8(out)
|
|
||||||
}
|
|
||||||
(S::BF16(inp), S::BF16(k)) => {
|
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
|
||||||
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
|
|
||||||
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
|
|
||||||
// version.
|
|
||||||
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
|
|
||||||
crate::cudnn::launch_conv1d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
|
|
||||||
.map_err(crate::Error::wrap)?;
|
|
||||||
S::BF16(out)
|
|
||||||
}
|
|
||||||
(S::F16(inp), S::F16(k)) => {
|
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
|
||||||
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
|
|
||||||
crate::cudnn::launch_conv1d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
|
|
||||||
.map_err(crate::Error::wrap)?;
|
|
||||||
S::F16(out)
|
|
||||||
}
|
|
||||||
(S::F32(inp), S::F32(k)) => {
|
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
|
||||||
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
|
|
||||||
crate::cudnn::launch_conv1d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
|
|
||||||
.map_err(crate::Error::wrap)?;
|
|
||||||
S::F32(out)
|
|
||||||
}
|
|
||||||
(S::F64(inp), S::F64(k)) => {
|
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
|
||||||
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
|
|
||||||
crate::cudnn::launch_conv1d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
|
|
||||||
.map_err(crate::Error::wrap)?;
|
|
||||||
S::F64(out)
|
|
||||||
}
|
|
||||||
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?,
|
|
||||||
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?,
|
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?,
|
|
||||||
};
|
|
||||||
Ok(Self { slice, device })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn conv_transpose1d(
|
fn conv_transpose1d(
|
||||||
&self,
|
&self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
@ -1644,11 +1587,12 @@ impl BackendStorage for CudaStorage {
|
|||||||
let n = params.c_out;
|
let n = params.c_out;
|
||||||
let k = params.k_h * params.k_w * params.c_in;
|
let k = params.k_h * params.k_w * params.c_in;
|
||||||
let m = h_out * w_out;
|
let m = h_out * w_out;
|
||||||
let col_l = Layout::contiguous((b * m, k));
|
let col_l = Layout::contiguous((b, m, k));
|
||||||
let res = if kernel_l.is_contiguous() {
|
let res = if kernel_l.is_contiguous() {
|
||||||
let kernel_l =
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
.transpose(1, 2)?
|
||||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
} else {
|
} else {
|
||||||
// Make the kernel contiguous if not already the case.
|
// Make the kernel contiguous if not already the case.
|
||||||
let mut kernel_c = unsafe {
|
let mut kernel_c = unsafe {
|
||||||
@ -1656,9 +1600,10 @@ impl BackendStorage for CudaStorage {
|
|||||||
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
|
||||||
};
|
};
|
||||||
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
|
||||||
let kernel_l =
|
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
|
||||||
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
|
.transpose(1, 2)?
|
||||||
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
|
.broadcast_as((b, k, n))?;
|
||||||
|
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
|
||||||
};
|
};
|
||||||
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
let res_l = Layout::contiguous((b, h_out, w_out, n))
|
||||||
.transpose(1, 2)?
|
.transpose(1, 2)?
|
||||||
@ -1687,7 +1632,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(S::U8(inp), S::U8(k)) => {
|
(S::U8(inp), S::U8(k)) => {
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
|
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
|
||||||
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
|
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
|
||||||
.map_err(crate::Error::wrap)?;
|
.map_err(crate::Error::wrap)?;
|
||||||
S::U8(out)
|
S::U8(out)
|
||||||
@ -1695,7 +1640,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(S::BF16(inp), S::BF16(k)) => {
|
(S::BF16(inp), S::BF16(k)) => {
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
|
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
|
||||||
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
|
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
|
||||||
// version.
|
// version.
|
||||||
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
|
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
|
||||||
@ -1706,7 +1651,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(S::F16(inp), S::F16(k)) => {
|
(S::F16(inp), S::F16(k)) => {
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
|
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
|
||||||
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
|
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
|
||||||
.map_err(crate::Error::wrap)?;
|
.map_err(crate::Error::wrap)?;
|
||||||
S::F16(out)
|
S::F16(out)
|
||||||
@ -1714,7 +1659,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(S::F32(inp), S::F32(k)) => {
|
(S::F32(inp), S::F32(k)) => {
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
|
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
|
||||||
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
|
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
|
||||||
.map_err(crate::Error::wrap)?;
|
.map_err(crate::Error::wrap)?;
|
||||||
S::F32(out)
|
S::F32(out)
|
||||||
@ -1722,7 +1667,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(S::F64(inp), S::F64(k)) => {
|
(S::F64(inp), S::F64(k)) => {
|
||||||
let inp = &inp.slice(inp_l.start_offset()..);
|
let inp = &inp.slice(inp_l.start_offset()..);
|
||||||
let k = &k.slice(kernel_l.start_offset()..);
|
let k = &k.slice(kernel_l.start_offset()..);
|
||||||
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
|
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
|
||||||
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
|
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
|
||||||
.map_err(crate::Error::wrap)?;
|
.map_err(crate::Error::wrap)?;
|
||||||
S::F64(out)
|
S::F64(out)
|
||||||
@ -1838,7 +1783,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||||
let mut out = unsafe { dev.alloc::<bf16>(elem_count)? };
|
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
|
||||||
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||||
.w()?;
|
.w()?;
|
||||||
CudaStorageSlice::BF16(out)
|
CudaStorageSlice::BF16(out)
|
||||||
@ -1847,7 +1792,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
|
||||||
let mut out = unsafe { dev.alloc::<f16>(elem_count)? };
|
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
|
||||||
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||||
.w()?;
|
.w()?;
|
||||||
CudaStorageSlice::F16(out)
|
CudaStorageSlice::F16(out)
|
||||||
@ -1856,7 +1801,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||||
let mut out = unsafe { dev.alloc::<f32>(elem_count)? };
|
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
|
||||||
.w()?;
|
.w()?;
|
||||||
CudaStorageSlice::F32(out)
|
CudaStorageSlice::F32(out)
|
||||||
@ -1865,7 +1810,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
let lhs = &lhs.slice(lhs_l.start_offset()..);
|
||||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
|
||||||
let mut out = unsafe { dev.alloc::<f64>(elem_count)? };
|
let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
|
||||||
unsafe {
|
unsafe {
|
||||||
self.device
|
self.device
|
||||||
.blas
|
.blas
|
||||||
@ -1938,7 +1883,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -1954,7 +1899,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -1970,7 +1915,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -1986,7 +1931,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
|
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -2002,7 +1947,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -2018,7 +1963,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
|
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
@ -2034,7 +1979,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_l.is_contiguous() {
|
if src_l.is_contiguous() {
|
||||||
dev.memcpy_dtod(&src, &mut dst)?
|
dev.memcpy_dtod(&src, &mut dst).w()?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
|
@ -816,7 +816,7 @@ impl PthTensors {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `path` - Path to the pth file.
|
/// * `path` - Path to the pth file.
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||||
path: P,
|
path: P,
|
||||||
key: Option<&str>,
|
key: Option<&str>,
|
||||||
|
@ -73,7 +73,7 @@ fn dequantize_f32(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
let nb = elem_count.div_ceil(256);
|
let nb = (elem_count + 255) / 256;
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||||
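The only substantive change in this hunk is the block-count expression: `elem_count.div_ceil(256)` on one side and the manual `(elem_count + 255) / 256` on the other. Both are a ceiling division by 256 and agree for all practical sizes (they only differ if the addition would overflow `usize`):

```rust
fn blocks_div_ceil(elem_count: usize) -> usize {
    elem_count.div_ceil(256)
}

fn blocks_manual(elem_count: usize) -> usize {
    (elem_count + 255) / 256
}

fn main() {
    for n in [0usize, 1, 255, 256, 257, 1 << 20] {
        assert_eq!(blocks_div_ceil(n), blocks_manual(n));
    }
    println!("ceil(1000 / 256) = {}", blocks_div_ceil(1000));
}
```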
@ -99,7 +99,7 @@ fn dequantize_f32(
|
|||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
|
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
|
||||||
// See e.g.
|
// See e.g.
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
@ -133,7 +133,7 @@ fn dequantize_f16(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
let nb = elem_count.div_ceil(256);
|
let nb = (elem_count + 255) / 256;
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||||
@ -159,7 +159,7 @@ fn dequantize_f16(
|
|||||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
|
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
|
||||||
// See e.g.
|
// See e.g.
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
@ -216,7 +216,7 @@ fn dequantize_mul_mat_vec(
|
|||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows)? };
|
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
|
||||||
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (block_num_y as u32, 1, 1),
|
grid_dim: (block_num_y as u32, 1, 1),
|
||||||
@ -256,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
|
||||||
|
|
||||||
let kernel_name = match dtype {
|
let kernel_name = match dtype {
|
||||||
@ -274,12 +274,12 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
};
|
};
|
||||||
let kernel_name = format!("{kernel_name}{b_size}");
|
let kernel_name = format!("{kernel_name}{b_size}");
|
||||||
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
|
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
|
||||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||||
let (nblocks, nwarps) = match b_size {
|
let (nblocks, nwarps) = match b_size {
|
||||||
1 => (nrows as u32, 4),
|
1 => (nrows as u32, 4),
|
||||||
2..=4 => ((nrows as u32).div_ceil(2), 4),
|
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
||||||
5..=8 => ((nrows as u32).div_ceil(2), 2),
|
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
||||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||||
};
|
};
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
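Before the quantized mat-vec kernel runs, the `y` activations are re-quantized to Q8_1 into a staging buffer whose size comes from the padded column count. A sketch of that sizing arithmetic; the padding constant and the 36-bytes-per-32-values Q8_1 block layout are assumptions written down for illustration, not read from this diff:

```rust
// Assumed constants: rows padded to a multiple of 512 elements, and a Q8_1
// block holding 32 quantized values in 36 bytes (32 x i8 plus two f16 scales).
const MATRIX_ROW_PADDING: usize = 512;
const Q8_1_BLOCK_SIZE: usize = 32;
const Q8_1_TYPE_SIZE: usize = 36;

fn pad(x: usize, to: usize) -> usize {
    x.div_ceil(to) * to
}

fn q8_1_buffer_bytes(b_size: usize, ncols: usize) -> usize {
    let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
    b_size * ncols_padded * Q8_1_TYPE_SIZE / Q8_1_BLOCK_SIZE
}

fn main() {
    // 4 activation rows of 4096 columns.
    println!("staging buffer: {} bytes", q8_1_buffer_bytes(4, 4096));
}
```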
@ -329,7 +329,7 @@ fn mul_mat_via_q8_1(
|
|||||||
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
let k_padded = pad(k, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
|
||||||
|
|
||||||
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
let (kernel_name, mmq_x, mmq_y) = match dtype {
|
||||||
@ -346,7 +346,7 @@ fn mul_mat_via_q8_1(
|
|||||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
|
||||||
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
|
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
grid_dim: (
|
grid_dim: (
|
||||||
ceil_div(x_rows, mmq_y) as u32,
|
ceil_div(x_rows, mmq_y) as u32,
|
||||||
@ -378,7 +378,7 @@ impl QCudaStorage {
|
|||||||
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
|
||||||
let padded_size_in_bytes =
|
let padded_size_in_bytes =
|
||||||
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
|
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
|
||||||
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
|
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
|
||||||
Ok(QCudaStorage {
|
Ok(QCudaStorage {
|
||||||
data: PaddedCudaSlice {
|
data: PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
@ -425,7 +425,8 @@ impl QCudaStorage {
|
|||||||
|
|
||||||
let buffer = self
|
let buffer = self
|
||||||
.device
|
.device
|
||||||
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
|
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
|
||||||
|
.w()?;
|
||||||
let mut out = vec![0.0; elem_count];
|
let mut out = vec![0.0; elem_count];
|
||||||
let block_len = elem_count / self.dtype.block_size();
|
let block_len = elem_count / self.dtype.block_size();
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
@ -456,7 +457,9 @@ impl QCudaStorage {
|
|||||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||||
// Run the quantization on cpu.
|
// Run the quantization on cpu.
|
||||||
let src = match &src.slice {
|
let src = match &src.slice {
|
||||||
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
|
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
||||||
|
self.device.memcpy_dtov(data).w()?
|
||||||
|
}
|
||||||
_ => crate::bail!("only f32 can be quantized"),
|
_ => crate::bail!("only f32 can be quantized"),
|
||||||
};
|
};
|
||||||
let src_len = src.len();
|
let src_len = src.len();
|
||||||
@ -466,9 +469,10 @@ impl QCudaStorage {
|
|||||||
let data = qcpu_storage.data()?;
|
let data = qcpu_storage.data()?;
|
||||||
let padded_len =
|
let padded_len =
|
||||||
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
|
||||||
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
|
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
|
||||||
self.device
|
self.device
|
||||||
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
|
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
|
||||||
|
.w()?;
|
||||||
self.data = PaddedCudaSlice {
|
self.data = PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
len: data.len(),
|
len: data.len(),
|
||||||
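The `quantize` hunk above shows the CPU round-trip this backend takes: download the f32 data, quantize it on the host, then upload it into a device buffer that carries a few extra blocks of row padding so vectorized kernels can over-read safely. A host-only sketch of that flow; the quantization step is a placeholder byte cast, not the real GGML block encoding:

```rust
const MATRIX_ROW_PADDING: usize = 512;

fn quantize_roundtrip(src_f32: &[f32], type_size: usize, block_size: usize) -> Vec<u8> {
    // "Quantize" on the host (placeholder: one byte per element).
    let data: Vec<u8> = src_f32.iter().map(|&v| v as u8).collect();
    // Padded length mirrors the formula in the diff:
    // data.len() + MATRIX_ROW_PADDING * type_size / block_size.
    let padded_len = data.len() + MATRIX_ROW_PADDING * type_size / block_size;
    let mut device_buf = vec![0u8; padded_len];
    device_buf[..data.len()].copy_from_slice(&data);
    device_buf
}

fn main() {
    let buf = quantize_roundtrip(&[1.0, 2.0, 3.0], 36, 32);
    assert!(buf.len() > 3);
    println!("padded buffer of {} bytes", buf.len());
}
```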
@ -602,8 +606,10 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
|||||||
};
|
};
|
||||||
let dtype = T::DTYPE;
|
let dtype = T::DTYPE;
|
||||||
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
|
||||||
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
|
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
|
||||||
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
|
device
|
||||||
|
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
|
||||||
|
.w()?;
|
||||||
Ok(QStorage::Cuda(QCudaStorage {
|
Ok(QStorage::Cuda(QCudaStorage {
|
||||||
data: PaddedCudaSlice {
|
data: PaddedCudaSlice {
|
||||||
inner,
|
inner,
|
||||||
@ -625,9 +631,9 @@ mod test {
|
|||||||
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
let el_padded = pad(el, MATRIX_ROW_PADDING);
|
||||||
let y_size_in_bytes =
|
let y_size_in_bytes =
|
||||||
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
|
||||||
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
|
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
|
||||||
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
|
||||||
let y = dev.memcpy_stod(&vs)?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -637,7 +643,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let ncols = 256;
|
let ncols = 256;
|
||||||
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
|
||||||
let y = dev.memcpy_stod(&vs)?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_vec_via_q8_1(
|
let cuda_storage = mul_mat_vec_via_q8_1(
|
||||||
@ -650,7 +656,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
assert_eq!(vs.len(), 1);
|
assert_eq!(vs.len(), 1);
|
||||||
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
|
||||||
// Q8 means 1/256 precision.
|
// Q8 means 1/256 precision.
|
||||||
@ -665,7 +671,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
assert_eq!(vs.len(), 1);
|
assert_eq!(vs.len(), 1);
|
||||||
assert_eq!(vs[0], 5561851.0);
|
assert_eq!(vs[0], 5561851.0);
|
||||||
Ok(())
|
Ok(())
|
||||||
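The comment in this test leans on the closed form for a sum of squares: dotting `[0., 1., ..., 255.]` with itself gives 255·256·511/6 = 5 559 680 exactly, and the asserted Q8_1 result of 5 561 851 sits within the stated 1/256 relative precision. A two-line check of the identity:

```rust
fn main() {
    let n = 255u64;
    let brute: u64 = (0..=n).map(|v| v * v).sum();
    let closed = n * (n + 1) * (2 * n + 1) / 6;
    assert_eq!(brute, closed);
    println!("sum of squares up to {n} = {closed}");
}
```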
@ -676,7 +682,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let ncols = 256;
|
let ncols = 256;
|
||||||
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
|
||||||
let y = dev.memcpy_stod(&vs)?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_via_q8_1(
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
@ -690,7 +696,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let vs = dev.memcpy_dtov(&vs.slice(..))?;
|
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
|
|
||||||
/*
|
/*
|
||||||
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
|
||||||
@ -717,7 +723,7 @@ mod test {
|
|||||||
let dev = CudaDevice::new(0)?;
|
let dev = CudaDevice::new(0)?;
|
||||||
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
let (x_rows, ncols, y_cols) = (4, 16, 2048);
|
||||||
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
|
||||||
let y = dev.memcpy_stod(&vs)?;
|
let y = dev.memcpy_stod(&vs).w()?;
|
||||||
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
|
||||||
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
|
||||||
let cuda_storage = mul_mat_via_q8_1(
|
let cuda_storage = mul_mat_via_q8_1(
|
||||||
@ -731,7 +737,7 @@ mod test {
|
|||||||
&dev,
|
&dev,
|
||||||
)?;
|
)?;
|
||||||
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
let vs = cuda_storage.as_cuda_slice::<f32>()?;
|
||||||
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
|
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -76,7 +76,7 @@ mod cuda {
|
|||||||
Some((o1, o2)) => src.slice(o1..o2),
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
|
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
|
||||||
let func = if self.asc {
|
let func = if self.asc {
|
||||||
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
|
||||||
} else {
|
} else {
|
||||||
|
@ -53,20 +53,6 @@ fn conv1d(dev: &Device) -> Result<()> {
|
|||||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||||
);
|
);
|
||||||
let res = {
|
|
||||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
|
||||||
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
|
|
||||||
};
|
|
||||||
assert_eq!(res.dims(), [3, 2, 5]);
|
|
||||||
// Same as pytorch default padding: use zeros.
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
|
||||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
|
||||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
|
||||||
);
|
|
||||||
|
|
||||||
let w = w.transpose(0, 1)?;
|
let w = w.transpose(0, 1)?;
|
||||||
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
// The CPU kernels applied in the contiguous and non contiguous cases are different.
|
||||||
@ -177,22 +163,6 @@ fn conv2d(dev: &Device) -> Result<()> {
|
|||||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let res = {
|
|
||||||
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
|
|
||||||
t.conv2d(&w, 0, 1, 1, 1)?
|
|
||||||
};
|
|
||||||
assert_eq!(res.dims(), [3, 2, 3, 3]);
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
|
|
||||||
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
|
|
||||||
);
|
|
||||||
assert_eq!(
|
|
||||||
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
|
|
||||||
[
|
|
||||||
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
|
|
||||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
|
||||||
]
|
|
||||||
);
|
|
||||||
|
|
||||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
|
|||||||
default = []
|
default = []
|
||||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
|
||||||
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
|
cudnn = ["candle/cudnn"]
|
||||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||||
@ -69,7 +69,6 @@ metal = ["candle/metal", "candle-nn/metal"]
|
|||||||
microphone = ["cpal", "rubato"]
|
microphone = ["cpal", "rubato"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
mimi = ["cpal", "symphonia", "rubato"]
|
mimi = ["cpal", "symphonia", "rubato"]
|
||||||
snac = ["cpal", "symphonia", "rubato"]
|
|
||||||
depth_anything_v2 = ["palette", "enterpolation"]
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
@ -108,10 +107,6 @@ required-features = ["candle-datasets"]
|
|||||||
name = "mimi"
|
name = "mimi"
|
||||||
required-features = ["mimi"]
|
required-features = ["mimi"]
|
||||||
|
|
||||||
[[example]]
|
|
||||||
name = "snac"
|
|
||||||
required-features = ["snac"]
|
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["encodec"]
|
required-features = ["encodec"]
|
||||||
|
@ -1,14 +0,0 @@
|
|||||||
# Conversational Speech Model (CSM)
|
|
||||||
|
|
||||||
CSM is a speech generation model from Sesame,
|
|
||||||
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
|
|
||||||
|
|
||||||
It can generate conversational speech between two different speakers.
|
|
||||||
Speaker turns are delimited by the `|` character in the prompt.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cargo run --example csm --features cuda -r -- \
|
|
||||||
--voices candle-examples/examples/csm/voices.safetensors \
|
|
||||||
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
|
||||||
```
|
|
||||||
|
|
@ -34,18 +34,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
|
|
||||||
/// The prompt to be used for the generation, use a | to separate the speakers.
|
#[arg(long, default_value = "[0]Hey how are you doing?")]
|
||||||
#[arg(long, default_value = "Hey how are you doing today?")]
|
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
/// The voices to be used, in safetensors format.
|
|
||||||
#[arg(long)]
|
|
||||||
voices: String,
|
|
||||||
|
|
||||||
/// The output file using the wav format.
|
|
||||||
#[arg(long, default_value = "out.wav")]
|
|
||||||
out_file: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long, default_value_t = 0.7)]
|
#[arg(long, default_value_t = 0.7)]
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
@ -171,7 +162,7 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (mut model, device) = {
|
let (mut model, device) = {
|
||||||
let dtype = device.bf16_default_to_f32();
|
let dtype = DType::F32;
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
(model, device)
|
(model, device)
|
||||||
@ -186,58 +177,45 @@ fn main() -> Result<()> {
|
|||||||
let cb = config.audio_num_codebooks;
|
let cb = config.audio_num_codebooks;
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
if args.prompt.ends_with(".safetensors") {
|
||||||
let voices = candle::safetensors::load(args.voices, &device)?;
|
let prompt = candle::safetensors::load(args.prompt, &device)?;
|
||||||
let mut lp = candle_transformers::generation::LogitsProcessor::new(
|
let mut tokens = prompt
|
||||||
args.seed,
|
.get("tokens")
|
||||||
Some(args.temperature),
|
.expect("no tokens in prompt")
|
||||||
None,
|
.to_dtype(DType::U32)?;
|
||||||
);
|
let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
|
||||||
let tokens = voices
|
println!("tokens:\n{tokens:?}");
|
||||||
.get("tokens")
|
println!("mask:\n{mask:?}");
|
||||||
.expect("no tokens in prompt")
|
let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
|
||||||
.to_dtype(DType::U32)?;
|
let mut const_mask = vec![1u8; cb];
|
||||||
let mask = voices.get("mask").expect("no mask in prompt").clone();
|
const_mask.push(0);
|
||||||
|
let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
let mut all_tokens = vec![];
|
||||||
pos += tokens.dim(1)?;
|
for i in 0.. {
|
||||||
|
let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
let mut all_pcms = vec![];
|
|
||||||
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
|
|
||||||
println!("{prompt:?}");
|
|
||||||
let speaker_idx = turn_idx % 2;
|
|
||||||
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
|
|
||||||
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
|
|
||||||
|
|
||||||
let mut generated_tokens = vec![];
|
|
||||||
loop {
|
|
||||||
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
|
||||||
pos += tokens.dim(1)?;
|
pos += tokens.dim(1)?;
|
||||||
let is_done = frame.iter().all(|&x| x == 0);
|
frame.push(0);
|
||||||
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
|
if frame.iter().all(|&x| x == 0) {
|
||||||
print!("\rframe {pos}");
|
|
||||||
if is_done {
|
|
||||||
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
|
||||||
pos += tokens.dim(1)?;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
generated_tokens.push(tokens.clone());
|
println!("frame {i} {pos}:\n{frame:?}");
|
||||||
|
tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
|
||||||
|
all_tokens.push(tokens.clone());
|
||||||
|
mask = const_mask.clone();
|
||||||
}
|
}
|
||||||
println!();
|
let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||||
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
println!("all_tokens:\n{all_tokens:?}");
|
||||||
let pcm = mimi_model.decode(&generated_tokens)?;
|
let pcm = mimi_model.decode(&all_tokens)?;
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||||
all_pcms.push(pcm);
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
let mut output = std::fs::File::create("out.wav")?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
} else {
|
||||||
|
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
||||||
|
println!("{prompt:?}");
|
||||||
}
|
}
|
||||||
let pcm = Tensor::cat(&all_pcms, 0)?;
|
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
|
||||||
println!("writing output file {}", args.out_file);
|
|
||||||
let mut output = std::fs::File::create(args.out_file)?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
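The multi-speaker path shown on the left-hand side of this hunk splits the prompt on `|`, alternates speaker ids 0 and 1, and wraps each turn as `[speaker]text<|end_of_text|>` before tokenization. A standalone sketch of just that turn formatting, with tokenization and generation omitted; only the format string is taken from the diff:

```rust
fn speaker_prompts(prompt: &str) -> Vec<String> {
    prompt
        .split('|')
        .enumerate()
        .map(|(turn_idx, text)| {
            let speaker_idx = turn_idx % 2;
            format!("[{speaker_idx}]{}<|end_of_text|>", text)
        })
        .collect()
}

fn main() {
    let turns = speaker_prompts("Hey how are you doing?|Pretty good, pretty good. How about you?");
    assert_eq!(turns.len(), 2);
    for t in &turns {
        println!("{t}");
    }
}
```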
|
Binary file not shown.
@ -68,7 +68,7 @@ impl CustomOp1 for LayerNorm {
|
|||||||
Some((o1, o2)) => slice.slice(o1..o2),
|
Some((o1, o2)) => slice.slice(o1..o2),
|
||||||
};
|
};
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
|
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||||
let func =
|
let func =
|
||||||
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||||
let cfg = LaunchConfig {
|
let cfg = LaunchConfig {
|
||||||
|
@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model we
|
|||||||
are downloaded from the hub on the first run.
|
are downloaded from the hub on the first run.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||||
@ -20,25 +20,3 @@ $ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
|||||||
> Tensor[[1, 7, 768], f32]
|
> Tensor[[1, 7, 768], f32]
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Masked Token
|
|
||||||
|
|
||||||
DistilBert is used to compute the top K choices for a masked token.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
|
|
||||||
|
|
||||||
> Input: The capital of France is [MASK].
|
|
||||||
> Predictions for [MASK] at position 6:
|
|
||||||
> 1: marseille (probability: 12.14%)
|
|
||||||
> 2: paris (probability: 10.84%)
|
|
||||||
> 3: toulouse (probability: 8.57%)
|
|
||||||
> 4: lyon (probability: 7.61%)
|
|
||||||
> 5: montpellier (probability: 5.18%)
|
|
||||||
> 6: bordeaux (probability: 4.88%)
|
|
||||||
> 7: nantes (probability: 4.82%)
|
|
||||||
> 8: lille (probability: 4.07%)
|
|
||||||
> 9: strasbourg (probability: 3.12%)
|
|
||||||
> 10: cannes (probability: 3.04%)
|
|
||||||
|
|
||||||
```
|
|
@@ -3,48 +3,15 @@ extern crate intel_mkl_src;

 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;
-use candle_transformers::models::distilbert::{
-Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
-};
+use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};

-use anyhow::{Context, Error as E, Result};
+use anyhow::{Error as E, Result};
 use candle::{Device, Tensor};
 use candle_nn::VarBuilder;
-use clap::{Parser, ValueEnum};
+use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
-use std::path::PathBuf;
 use tokenizers::Tokenizer;

-enum ModelType {
-Masked(DistilBertForMaskedLM),
-UnMasked(DistilBertModel),
-}
-
-impl ModelType {
-fn device(&self) -> &Device {
-match self {
-ModelType::Masked(model) => &model.bert.device,
-ModelType::UnMasked(model) => &model.device,
-}
-}
-
-fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
-match self {
-ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
-ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
-}
-}
-}
-
-#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
-enum Which {
-#[value(name = "distilbert")]
-DistilBert,
-
-#[value(name = "distilbertformaskedlm")]
-DistilbertForMaskedLM,
-}
-
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -56,14 +23,10 @@ struct Args {
 #[arg(long)]
 tracing: bool,

-#[arg(long, default_value = "distilbert")]
-model: Which,
-
 /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
 #[arg(long)]
 model_id: Option<String>,

-/// Revision or branch
 #[arg(long)]
 revision: Option<String>,

@@ -79,246 +42,94 @@ struct Args {
 #[arg(long, default_value = "1")]
 n: usize,

-/// Number of top predictions to show for each mask
-#[arg(long, default_value = "5")]
-top_k: usize,
+/// L2 normalization for embeddings.
+#[arg(long, default_value = "true")]
+normalize_embeddings: bool,
 }

 impl Args {
-fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
+fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
 let device = candle_examples::device(self.cpu)?;

-let (model_id, revision) = self.resolve_model_and_revision();
-let (config_path, tokenizer_path, weights_path) =
-self.download_model_files(&model_id, &revision)?;
-
-let config = std::fs::read_to_string(config_path)?;
-let config: Config = serde_json::from_str(&config)?;
-let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
-
-let vb = self.load_variables(&weights_path, &device)?;
-let model = self.create_model(&config, vb)?;
-
-Ok((model, tokenizer))
-}
-
-fn resolve_model_and_revision(&self) -> (String, String) {
 let default_model = "distilbert-base-uncased".to_string();
 let default_revision = "main".to_string();
-match (self.model_id.clone(), self.revision.clone()) {
+let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
 (Some(model_id), Some(revision)) => (model_id, revision),
-(Some(model_id), None) => (model_id, default_revision),
+(Some(model_id), None) => (model_id, "main".to_string()),
 (None, Some(revision)) => (default_model, revision),
 (None, None) => (default_model, default_revision),
-}
-}
-
-fn download_model_files(
-&self,
-model_id: &str,
-revision: &str,
-) -> Result<(PathBuf, PathBuf, PathBuf)> {
-let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
-let api = Api::new()?;
-let api = api.repo(repo);
-
-let config = api.get("config.json")?;
-let tokenizer = api.get("tokenizer.json")?;
-let weights = if self.use_pth {
-api.get("pytorch_model.bin")?
-} else {
-api.get("model.safetensors")?
 };

-Ok((config, tokenizer, weights))
-}
+let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+let (config_filename, tokenizer_filename, weights_filename) = {
+let api = Api::new()?;
+let api = api.repo(repo);
+let config = api.get("config.json")?;
+let tokenizer = api.get("tokenizer.json")?;
+let weights = if self.use_pth {
+api.get("pytorch_model.bin")?
+} else {
+api.get("model.safetensors")?
+};
+(config, tokenizer, weights)
+};
+let config = std::fs::read_to_string(config_filename)?;
+let config: Config = serde_json::from_str(&config)?;
+let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

-fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
-if self.use_pth {
-Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
+let vb = if self.use_pth {
+VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
 } else {
-Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
-}
+unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+};
+let model = DistilBertModel::load(vb, &config)?;
+Ok((model, tokenizer))
 }
+}

-fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
-match self.model {
-Which::DistilbertForMaskedLM => {
-Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
-}
-Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
-}
-}
-}
+fn get_mask(size: usize, device: &Device) -> Tensor {
+let mask: Vec<_> = (0..size)
+.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+.collect();
+Tensor::from_slice(&mask, (size, size), device).unwrap()
+}

 fn main() -> Result<()> {
+use tracing_chrome::ChromeLayerBuilder;
+use tracing_subscriber::prelude::*;
+
 let args = Args::parse();
-let _guard = setup_tracing(&args);
-
-let (model, tokenizer) = args.build_model_and_tokenizer()?;
-let device = model.device();
-
-let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
-let output = model.forward(&token_ids, &mask)?;
-
-process_output(&model, &output, &token_ids, &tokenizer, &args)?;
-
-Ok(())
-}
-
-fn setup_tracing(args: &Args) -> Option<impl Drop> {
-if args.tracing {
-use tracing_chrome::ChromeLayerBuilder;
-use tracing_subscriber::prelude::*;
-
+let _guard = if args.tracing {
 println!("tracing...");
 let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
 tracing_subscriber::registry().with(chrome_layer).init();
 Some(guard)
 } else {
 None
-}
-}
+};
+let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
+let device = &model.device;

-fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
-let mut binding = tokenizer.clone();
-let tokenizer_configured = binding
+let tokenizer = tokenizer
 .with_padding(None)
 .with_truncation(None)
 .map_err(E::msg)?;
-let tokens = tokenizer_configured
-.encode(args.prompt.clone(), true)
+let tokens = tokenizer
+.encode(args.prompt, true)
 .map_err(E::msg)?
 .get_ids()
 .to_vec();

 let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
+let mask = get_mask(tokens.len(), device);

-let mask = match args.model {
-Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
-Which::DistilBert => attention_mask(tokens.len(), device)?,
-};
-
-println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
+println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
+println!("mask: {:?}", mask.to_vec2::<u8>());

-Ok((token_ids, mask))
-}
-
-fn process_output(
-model: &ModelType,
-output: &Tensor,
-token_ids: &Tensor,
-tokenizer: &Tokenizer,
-args: &Args,
-) -> Result<()> {
-match model {
-ModelType::UnMasked(_) => {
-println!("embeddings");
-println!("{output}");
-}
-ModelType::Masked(_) => {
-process_masked_output(output, token_ids, tokenizer, args)?;
-}
-}
-
+let ys = model.forward(&token_ids, &mask)?;
+println!("{ys}");
 Ok(())
 }

-fn process_masked_output(
-output: &Tensor,
-token_ids: &Tensor,
-tokenizer: &Tokenizer,
-args: &Args,
-) -> Result<()> {
-let input_ids_vec = token_ids.to_vec2::<u32>()?;
-let mask_token_id = tokenizer
-.token_to_id("[MASK]")
-.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
-
-println!("\nInput: {}", args.prompt);
-
-for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
-if token_id == mask_token_id {
-println!("Predictions for [MASK] at position {}:", token_idx);
-
-let pos_logits = output.get(0)?.get(token_idx)?;
-let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
-let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
-
-let values = top_values.to_vec1::<f32>()?;
-let indices = top_indices.to_vec1::<u32>()?;
-
-for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
-let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
-println!(
-" {}: {:15} (probability: {:.2}%)",
-i + 1,
-token,
-prob * 100.0
-);
-}
-}
-}
-
-Ok(())
-}
-
-fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
-let n = tensor.dims().iter().product::<usize>();
-let k = std::cmp::min(k, n);
-
-let values = tensor.to_vec1::<f32>()?;
-let mut value_indices: Vec<(f32, usize)> = values
-.into_iter()
-.enumerate()
-.map(|(idx, val)| (val, idx))
-.collect();
-
-value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
-
-let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
-let top_k_indices: Vec<u32> = value_indices
-.iter()
-.take(k)
-.map(|(_, idx)| *idx as u32)
-.collect();
-
-let device = tensor.device();
-let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
-let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
-
-Ok((top_values, top_indices))
-}
-
-fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
-let mask: Vec<_> = (0..size)
-.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
-.collect();
-Ok(Tensor::from_slice(&mask, (size, size), device)?)
-}
-
-fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
-let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
-let seq_len = tokens.get_attention_mask().to_vec().len();
-
-let mask_token_id = tokenizer
-.token_to_id("[MASK]")
-.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
-
-let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
-
-let ids = tokens.get_ids();
-for _ in 0..seq_len {
-for id in ids.iter() {
-let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
-attention_mask_vec.push(mask_value);
-}
-}
-
-let shape = (1, 1, seq_len, seq_len);
-let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
-
-Ok(mask)
+pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
 }
@@ -46,7 +46,7 @@ impl TextGeneration {
 Sampling::ArgMax
 } else {
 match (top_k, top_p) {
-(None, None) => Sampling::GumbelSoftmax { temperature },
+(None, None) => Sampling::All { temperature },
 (Some(k), None) => Sampling::TopK { k, temperature },
 (None, Some(p)) => Sampling::TopP { p, temperature },
 (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
@@ -21,7 +21,7 @@ impl Config {
 }

 fn dt_rank(&self) -> usize {
-self.d_model.div_ceil(16)
+(self.d_model + 15) / 16
 }

 fn d_conv(&self) -> usize {
@@ -1,14 +0,0 @@
-# Orpheus
-
-Orpheus is a 3B text-to-speech model based on Llama.
-
-- Weights on HuggingFace
-[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
-- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
-
-
-```bash
-cargo run --example orpheus --features cuda -r
-```
-
-
@ -1,329 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle::{DType, Device, IndexOp, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
|
|
||||||
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
|
|
||||||
const STOP_TOKEN_ID: u32 = 128258;
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
/// Display the token for the specified prompt.
|
|
||||||
#[arg(long)]
|
|
||||||
verbose_prompt: bool,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "Hey, how are you doing today?")]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long, default_value_t = 0.6)]
|
|
||||||
temperature: f64,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// Only sample among the top K samples.
|
|
||||||
#[arg(long)]
|
|
||||||
top_k: Option<usize>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
/// The output wav file.
|
|
||||||
#[arg(long, default_value = "out.wav")]
|
|
||||||
out_file: String,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "3b-0.1-ft")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "tara")]
|
|
||||||
voice: Voice,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
use_flash_attn: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum Voice {
|
|
||||||
#[value(name = "tara")]
|
|
||||||
Tara,
|
|
||||||
#[value(name = "leah")]
|
|
||||||
Leah,
|
|
||||||
#[value(name = "jess")]
|
|
||||||
Jess,
|
|
||||||
#[value(name = "leo")]
|
|
||||||
Leo,
|
|
||||||
#[value(name = "dan")]
|
|
||||||
Dan,
|
|
||||||
#[value(name = "mia")]
|
|
||||||
Mia,
|
|
||||||
#[value(name = "zac")]
|
|
||||||
Zac,
|
|
||||||
#[value(name = "zoe")]
|
|
||||||
Zoe,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Voice {
|
|
||||||
fn as_str(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Voice::Tara => "tara",
|
|
||||||
Voice::Leah => "leah",
|
|
||||||
Voice::Jess => "jess",
|
|
||||||
Voice::Leo => "leo",
|
|
||||||
Voice::Dan => "dan",
|
|
||||||
Voice::Mia => "mia",
|
|
||||||
Voice::Zac => "zac",
|
|
||||||
Voice::Zoe => "zoe",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
#[value(name = "3b-0.1-ft")]
|
|
||||||
ThreeB0_1Ft,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
let prompt = args.prompt.clone();
|
|
||||||
let mut model = Model::load(args)?;
|
|
||||||
model.run(&prompt)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Model {
|
|
||||||
model: Llama,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
logits_processor: candle_transformers::generation::LogitsProcessor,
|
|
||||||
cache: Cache,
|
|
||||||
device: Device,
|
|
||||||
verbose_prompt: bool,
|
|
||||||
snac: SnacModel,
|
|
||||||
out_file: String,
|
|
||||||
voice: Voice,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_snac(device: &Device) -> Result<SnacModel> {
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
|
|
||||||
let config = m.get("config.json")?;
|
|
||||||
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
|
|
||||||
let m = api.model("lmz/candle-snac".to_string());
|
|
||||||
let model = m.get("snac_24khz.safetensors")?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
|
|
||||||
let model = SnacModel::new(&config, vb)?;
|
|
||||||
Ok(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Model {
|
|
||||||
fn load(args: Args) -> Result<Self> {
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id.to_string(),
|
|
||||||
None => match args.which {
|
|
||||||
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let revision = match args.revision {
|
|
||||||
Some(r) => r,
|
|
||||||
None => "main".to_string(),
|
|
||||||
};
|
|
||||||
let repo = api.repo(hf_hub::Repo::with_revision(
|
|
||||||
model_id,
|
|
||||||
hf_hub::RepoType::Model,
|
|
||||||
revision,
|
|
||||||
));
|
|
||||||
let model_files = match args.model_file {
|
|
||||||
Some(m) => vec![m.into()],
|
|
||||||
None => match args.which {
|
|
||||||
Which::ThreeB0_1Ft => {
|
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
let config = match args.config_file {
|
|
||||||
Some(m) => m.into(),
|
|
||||||
None => repo.get("config.json")?,
|
|
||||||
};
|
|
||||||
let tokenizer = match args.tokenizer_file {
|
|
||||||
Some(m) => m.into(),
|
|
||||||
None => repo.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let dtype = device.bf16_default_to_f32();
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
|
|
||||||
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
|
|
||||||
let config = config.into_config(args.use_flash_attn);
|
|
||||||
let model = Llama::load(vb, &config)?;
|
|
||||||
let logits_processor = {
|
|
||||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
|
||||||
let temperature = args.temperature;
|
|
||||||
let sampling = if temperature <= 0. {
|
|
||||||
Sampling::ArgMax
|
|
||||||
} else {
|
|
||||||
match (args.top_k.as_ref(), args.top_p.as_ref()) {
|
|
||||||
(None, None) => Sampling::All { temperature },
|
|
||||||
(Some(&k), None) => Sampling::TopK { k, temperature },
|
|
||||||
(None, Some(&p)) => Sampling::TopP { p, temperature },
|
|
||||||
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
|
|
||||||
}
|
|
||||||
};
|
|
||||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
let cache = Cache::new(true, dtype, &config, &device)?;
|
|
||||||
let snac = load_snac(&device)?;
|
|
||||||
Ok(Self {
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
logits_processor,
|
|
||||||
cache,
|
|
||||||
device,
|
|
||||||
verbose_prompt: args.verbose_prompt,
|
|
||||||
snac,
|
|
||||||
voice: args.voice,
|
|
||||||
out_file: args.out_file,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str) -> Result<()> {
|
|
||||||
println!("running the model on '{}'", prompt);
|
|
||||||
let device = &self.device;
|
|
||||||
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
|
|
||||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
|
||||||
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
|
|
||||||
let mut tokens = [
|
|
||||||
&[128259],
|
|
||||||
tokens.get_ids(),
|
|
||||||
&[128009, 128260, 128261, 128257],
|
|
||||||
]
|
|
||||||
.concat();
|
|
||||||
if self.verbose_prompt {
|
|
||||||
println!("{:?}", tokens);
|
|
||||||
}
|
|
||||||
let mut cache = self.cache.clone();
|
|
||||||
|
|
||||||
println!("starting the inference loop");
|
|
||||||
let mut index_pos = 0;
|
|
||||||
let mut audio_tokens = vec![];
|
|
||||||
for index in 0..2000 {
|
|
||||||
let (context_size, context_index) = if index > 0 {
|
|
||||||
(1, index_pos)
|
|
||||||
} else {
|
|
||||||
(tokens.len(), 0)
|
|
||||||
};
|
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
|
||||||
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, context_index, &mut cache)?;
|
|
||||||
let logits = logits.squeeze(0)?;
|
|
||||||
index_pos += ctxt.len();
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
|
|
||||||
match tok.strip_prefix("<custom_token_") {
|
|
||||||
Some(tok) => match tok.strip_suffix('>') {
|
|
||||||
Some(tok) => {
|
|
||||||
let tok = tok.parse::<u32>()?;
|
|
||||||
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
|
|
||||||
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
|
|
||||||
audio_tokens.push(tok);
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
println!("{index}: unexpected custom token {next_token} {tok}");
|
|
||||||
}
|
|
||||||
},
|
|
||||||
None => {
|
|
||||||
println!("{index}: unexpected token {next_token} {tok}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if next_token == STOP_TOKEN_ID {
|
|
||||||
println!("reached stop token");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
tokens.push(next_token);
|
|
||||||
}
|
|
||||||
println!("generated {} audio tokens", audio_tokens.len());
|
|
||||||
let mut codes0 = vec![];
|
|
||||||
let mut codes1 = vec![];
|
|
||||||
let mut codes2 = vec![];
|
|
||||||
for audio_tokens in audio_tokens.chunks_exact(7) {
|
|
||||||
codes0.push(audio_tokens[0]);
|
|
||||||
for i in [1, 4] {
|
|
||||||
codes1.push(audio_tokens[i]);
|
|
||||||
}
|
|
||||||
for i in [2, 3, 5, 6] {
|
|
||||||
codes2.push(audio_tokens[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
|
|
||||||
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
|
|
||||||
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
|
|
||||||
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
|
|
||||||
println!("decoded to pcm {pcm:?}");
|
|
||||||
let mut output = std::fs::File::create(&self.out_file)?;
|
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,275 +0,0 @@
|
|||||||
use anyhow::{Context, Result};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
pub const SAMPLE_RATE: usize = 24_000;
|
|
||||||
|
|
||||||
pub(crate) struct AudioOutputData_ {
|
|
||||||
resampled_data: std::collections::VecDeque<f32>,
|
|
||||||
resampler: rubato::FastFixedIn<f32>,
|
|
||||||
output_buffer: Vec<f32>,
|
|
||||||
input_buffer: Vec<f32>,
|
|
||||||
input_len: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AudioOutputData_ {
|
|
||||||
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
|
||||||
use rubato::Resampler;
|
|
||||||
|
|
||||||
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
|
||||||
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
|
||||||
let resampler = rubato::FastFixedIn::new(
|
|
||||||
resample_ratio,
|
|
||||||
f64::max(resample_ratio, 1.0),
|
|
||||||
rubato::PolynomialDegree::Septic,
|
|
||||||
1024,
|
|
||||||
1,
|
|
||||||
)?;
|
|
||||||
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
|
||||||
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
|
||||||
Ok(Self {
|
|
||||||
resampled_data,
|
|
||||||
resampler,
|
|
||||||
input_buffer,
|
|
||||||
output_buffer,
|
|
||||||
input_len: 0,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn reset(&mut self) {
|
|
||||||
use rubato::Resampler;
|
|
||||||
self.output_buffer.fill(0.);
|
|
||||||
self.input_buffer.fill(0.);
|
|
||||||
self.resampler.reset();
|
|
||||||
self.resampled_data.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
|
||||||
let mut data = Vec::with_capacity(self.resampled_data.len());
|
|
||||||
while let Some(elem) = self.resampled_data.pop_back() {
|
|
||||||
data.push(elem);
|
|
||||||
}
|
|
||||||
data
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn is_empty(&self) -> bool {
|
|
||||||
self.resampled_data.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assumes that the input buffer is large enough.
|
|
||||||
fn push_input_buffer(&mut self, samples: &[f32]) {
|
|
||||||
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
|
||||||
self.input_len += samples.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
|
||||||
use rubato::Resampler;
|
|
||||||
|
|
||||||
let mut pos_in = 0;
|
|
||||||
loop {
|
|
||||||
let rem = self.input_buffer.len() - self.input_len;
|
|
||||||
let pos_end = usize::min(pos_in + rem, samples.len());
|
|
||||||
self.push_input_buffer(&samples[pos_in..pos_end]);
|
|
||||||
pos_in = pos_end;
|
|
||||||
if self.input_len < self.input_buffer.len() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let (_, out_len) = self.resampler.process_into_buffer(
|
|
||||||
&[&self.input_buffer],
|
|
||||||
&mut [&mut self.output_buffer],
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
for &elem in self.output_buffer[..out_len].iter() {
|
|
||||||
self.resampled_data.push_front(elem)
|
|
||||||
}
|
|
||||||
self.input_len = 0;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
|
||||||
|
|
||||||
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
||||||
|
|
||||||
println!("Setup audio output stream!");
|
|
||||||
let host = cpal::default_host();
|
|
||||||
let device = host
|
|
||||||
.default_output_device()
|
|
||||||
.context("no output device available")?;
|
|
||||||
let mut supported_configs_range = device.supported_output_configs()?;
|
|
||||||
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
|
||||||
// On macOS, it's commonly the case that there are only stereo outputs.
|
|
||||||
None => device
|
|
||||||
.supported_output_configs()?
|
|
||||||
.next()
|
|
||||||
.context("no audio output available")?,
|
|
||||||
Some(config_range) => config_range,
|
|
||||||
};
|
|
||||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
|
||||||
config_range.min_sample_rate(),
|
|
||||||
config_range.max_sample_rate(),
|
|
||||||
);
|
|
||||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
|
||||||
let channels = config.channels as usize;
|
|
||||||
println!(
|
|
||||||
"cpal device: {} {} {config:?}",
|
|
||||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
|
||||||
config.sample_rate.0
|
|
||||||
);
|
|
||||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
|
||||||
SAMPLE_RATE,
|
|
||||||
config.sample_rate.0 as usize,
|
|
||||||
)?));
|
|
||||||
let ad = audio_data.clone();
|
|
||||||
let stream = device.build_output_stream(
|
|
||||||
&config,
|
|
||||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
|
||||||
data.fill(0.);
|
|
||||||
let mut ad = ad.lock().unwrap();
|
|
||||||
let mut last_elem = 0f32;
|
|
||||||
for (idx, elem) in data.iter_mut().enumerate() {
|
|
||||||
if idx % channels == 0 {
|
|
||||||
match ad.resampled_data.pop_back() {
|
|
||||||
None => break,
|
|
||||||
Some(v) => {
|
|
||||||
last_elem = v;
|
|
||||||
*elem = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
*elem = last_elem
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
move |err| eprintln!("cpal error: {err}"),
|
|
||||||
None, // None=blocking, Some(Duration)=timeout
|
|
||||||
)?;
|
|
||||||
stream.play()?;
|
|
||||||
Ok((stream, audio_data))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
||||||
|
|
||||||
println!("Setup audio input stream!");
|
|
||||||
let host = cpal::default_host();
|
|
||||||
let device = host
|
|
||||||
.default_input_device()
|
|
||||||
.context("no input device available")?;
|
|
||||||
let mut supported_configs_range = device.supported_input_configs()?;
|
|
||||||
let config_range = supported_configs_range
|
|
||||||
.find(|c| c.channels() == 1)
|
|
||||||
.context("no audio input available")?;
|
|
||||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
|
||||||
config_range.min_sample_rate(),
|
|
||||||
config_range.max_sample_rate(),
|
|
||||||
);
|
|
||||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
|
||||||
println!(
|
|
||||||
"cpal device: {} {} {config:?}",
|
|
||||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
|
||||||
config.sample_rate.0
|
|
||||||
);
|
|
||||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
|
||||||
config.sample_rate.0 as usize,
|
|
||||||
SAMPLE_RATE,
|
|
||||||
)?));
|
|
||||||
let ad = audio_data.clone();
|
|
||||||
let stream = device.build_input_stream(
|
|
||||||
&config,
|
|
||||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
||||||
let mut ad = ad.lock().unwrap();
|
|
||||||
if let Err(err) = ad.push_samples(data) {
|
|
||||||
eprintln!("error processing audio input {err:?}")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
move |err| eprintln!("cpal error: {err}"),
|
|
||||||
None, // None=blocking, Some(Duration)=timeout
|
|
||||||
)?;
|
|
||||||
stream.play()?;
|
|
||||||
Ok((stream, audio_data))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
|
||||||
where
|
|
||||||
T: symphonia::core::sample::Sample,
|
|
||||||
f32: symphonia::core::conv::FromSample<T>,
|
|
||||||
{
|
|
||||||
use symphonia::core::audio::Signal;
|
|
||||||
use symphonia::core::conv::FromSample;
|
|
||||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
|
||||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
|
||||||
|
|
||||||
let src = std::fs::File::open(path)?;
|
|
||||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
|
||||||
let hint = symphonia::core::probe::Hint::new();
|
|
||||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
|
||||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
|
||||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
|
||||||
let mut format = probed.format;
|
|
||||||
let track = format
|
|
||||||
.tracks()
|
|
||||||
.iter()
|
|
||||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
|
||||||
.expect("no supported audio tracks");
|
|
||||||
let mut decoder = symphonia::default::get_codecs()
|
|
||||||
.make(&track.codec_params, &Default::default())
|
|
||||||
.expect("unsupported codec");
|
|
||||||
let track_id = track.id;
|
|
||||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
|
||||||
let mut pcm_data = Vec::new();
|
|
||||||
while let Ok(packet) = format.next_packet() {
|
|
||||||
while !format.metadata().is_latest() {
|
|
||||||
format.metadata().pop();
|
|
||||||
}
|
|
||||||
if packet.track_id() != track_id {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
match decoder.decode(&packet)? {
|
|
||||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
|
||||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
|
||||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok((pcm_data, sample_rate))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
|
||||||
use rubato::Resampler;
|
|
||||||
|
|
||||||
let mut pcm_out =
|
|
||||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
|
||||||
|
|
||||||
let mut resampler =
|
|
||||||
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
|
|
||||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
|
||||||
let mut pos_in = 0;
|
|
||||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
|
||||||
let (in_len, out_len) =
|
|
||||||
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
|
||||||
pos_in += in_len;
|
|
||||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if pos_in < pcm_in.len() {
|
|
||||||
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
|
||||||
Some(&[&pcm_in[pos_in..]]),
|
|
||||||
&mut output_buffer,
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(pcm_out)
|
|
||||||
}
|
|
@ -1,197 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use candle::{DType, IndexOp, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::models::snac::{Config, Model};
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
use hf_hub::api::sync::Api;
|
|
||||||
|
|
||||||
mod audio_io;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
|
||||||
enum Action {
|
|
||||||
AudioToAudio,
|
|
||||||
AudioToCode,
|
|
||||||
CodeToAudio,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
#[value(name = "24khz")]
|
|
||||||
S24khz,
|
|
||||||
#[value(name = "32khz")]
|
|
||||||
S32khz,
|
|
||||||
#[value(name = "44khz")]
|
|
||||||
S44khz,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn sample_rate(&self) -> u32 {
|
|
||||||
match self {
|
|
||||||
Which::S24khz => 24000,
|
|
||||||
Which::S32khz => 32000,
|
|
||||||
Which::S44khz => 44000,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_repo(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Which::S24khz => "hubertsiuzdak/snac_24khz",
|
|
||||||
Which::S32khz => "hubertsiuzdak/snac_32khz",
|
|
||||||
Which::S44khz => "hubertsiuzdak/snac_44khz",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model_file(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Which::S24khz => "snac_24khz.safetensors",
|
|
||||||
Which::S32khz => "snac_32khz.safetensors",
|
|
||||||
Which::S44khz => "snac_44khz.safetensors",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// The action to be performed, specifies the format for the input and output data.
|
|
||||||
action: Action,
|
|
||||||
|
|
||||||
/// The input file, either an audio file or some snac tokens stored as safetensors.
|
|
||||||
in_file: String,
|
|
||||||
|
|
||||||
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
|
||||||
out_file: String,
|
|
||||||
|
|
||||||
/// The model size to use.
|
|
||||||
#[arg(long, default_value = "24khz")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// The model weight file, in safetensor format.
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
/// The config file, in safetensor format.
|
|
||||||
#[arg(long)]
|
|
||||||
config: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let model_sample_rate = args.which.sample_rate();
|
|
||||||
let config = match args.config {
|
|
||||||
Some(c) => std::path::PathBuf::from(c),
|
|
||||||
None => Api::new()?
|
|
||||||
.model(args.which.config_repo().to_string())
|
|
||||||
.get("config.json")?,
|
|
||||||
};
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
|
||||||
let model = match args.model {
|
|
||||||
Some(model) => std::path::PathBuf::from(model),
|
|
||||||
None => Api::new()?
|
|
||||||
.model("lmz/candle-snac".to_string())
|
|
||||||
.get(args.which.model_file())?,
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
|
|
||||||
let codes = match args.action {
|
|
||||||
Action::CodeToAudio => {
|
|
||||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
|
||||||
let num_codebooks = model.num_codebooks();
|
|
||||||
(0..num_codebooks)
|
|
||||||
.map(|i| {
|
|
||||||
codes
|
|
||||||
.get(&format!("codes-{i}"))
|
|
||||||
.expect("no codes in input file")
|
|
||||||
.clone()
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
}
|
|
||||||
Action::AudioToCode | Action::AudioToAudio => {
|
|
||||||
let pcm = if args.in_file == "-" {
|
|
||||||
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
|
||||||
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
|
||||||
let mut pcms = vec![];
|
|
||||||
let stdin = std::thread::spawn(|| {
|
|
||||||
let mut s = String::new();
|
|
||||||
std::io::stdin().read_line(&mut s)
|
|
||||||
});
|
|
||||||
while !stdin.is_finished() {
|
|
||||||
let input = input_audio.lock().unwrap().take_all();
|
|
||||||
if input.is_empty() {
|
|
||||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
pcms.push(input)
|
|
||||||
}
|
|
||||||
drop(stream);
|
|
||||||
pcms.concat()
|
|
||||||
} else {
|
|
||||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
|
||||||
if sample_rate != model_sample_rate {
|
|
||||||
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
|
|
||||||
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
|
|
||||||
} else {
|
|
||||||
pcm
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let pcm_len = pcm.len();
|
|
||||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
|
||||||
println!("input pcm shape: {:?}", pcm.shape());
|
|
||||||
model.encode(&pcm)?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for codes in codes.iter() {
|
|
||||||
println!("codes shape: {:?}", codes.shape());
|
|
||||||
}
|
|
||||||
|
|
||||||
match args.action {
|
|
||||||
Action::AudioToCode => {
|
|
||||||
let mut tensors = std::collections::HashMap::new();
|
|
||||||
for (i, codes) in codes.iter().enumerate() {
|
|
||||||
tensors.insert(format!("codes-{i}"), codes.clone());
|
|
||||||
}
|
|
||||||
candle::safetensors::save(&tensors, "codes.safetensors")?;
|
|
||||||
}
|
|
||||||
Action::AudioToAudio | Action::CodeToAudio => {
|
|
||||||
let codes = codes.iter().collect::<Vec<_>>();
|
|
||||||
let pcm = model.decode(&codes)?;
|
|
||||||
println!("output pcm shape: {:?}", pcm.shape());
|
|
||||||
let pcm = pcm.i(0)?.i(0)?;
|
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
|
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
|
||||||
if args.out_file == "-" {
|
|
||||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
|
||||||
{
|
|
||||||
let mut ad = ad.lock().unwrap();
|
|
||||||
ad.push_samples(&pcm)?;
|
|
||||||
}
|
|
||||||
loop {
|
|
||||||
let ad = ad.lock().unwrap();
|
|
||||||
if ad.is_empty() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
// That's very weird, calling thread::sleep here triggers the stream to stop
|
|
||||||
// playing (the callback doesn't seem to be called anymore).
|
|
||||||
// std::thread::sleep(std::time::Duration::from_millis(100));
|
|
||||||
}
|
|
||||||
drop(stream)
|
|
||||||
} else {
|
|
||||||
let mut output = std::fs::File::create(&args.out_file)?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@@ -133,7 +133,6 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
 padding,
 groups: 1,
 dilation: 1,
-cudnn_fwd_algo: None,
 };
 let conv = if bias {
 conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?
@@ -92,7 +92,6 @@ impl ConvBlock {
 stride,
 groups: 1,
 dilation: 1,
-cudnn_fwd_algo: None,
 };
 let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
 let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.9.0-alpha.3"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,17 +11,14 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.3" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]
 bindgen_cuda = "0.1.1"
 anyhow = { version = "1", features = ["backtrace"] }


 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
 candle-nn = { path = "../candle-nn", features = ["cuda"] }

-[features]
-default = []
-cudnn = ["candle/cudnn"]
@@ -2,6 +2,7 @@ mod ffi;

 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
+use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use half::{bf16, f16};

@@ -141,8 +142,10 @@ impl FlashAttn {
 let seqlen_k_rounded = round_multiple(seqlen_k, 128);

 let elem_count = out_shape.elem_count();
-let dst = unsafe { dev.alloc::<T>(elem_count)? };
-let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
+let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
+let softmax_lse = dev
+.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
+.w()?;

 let is_bf16 = if is_bf16 { 1 } else { 0 };

@@ -604,8 +607,8 @@ impl FlashAttnVarLen {
 let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);

 let elem_count = out_shape.elem_count();
-let dst = unsafe { dev.alloc::<f16>(elem_count)? };
-let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
+let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;

 let is_bf16 = if is_bf16 { 1 } else { 0 };

@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.9.0-alpha.3"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "CUDA kernels for Candle"
@@ -53,7 +53,7 @@ __device__ void conv1d(

 template <typename T>
 __device__ void im2col1d(
-const size_t numel,
+const size_t dst_numel,
 const size_t l_out,
 const size_t l_k,
 const size_t stride,
@@ -63,10 +63,10 @@ __device__ void im2col1d(
 const T *src,
 T *dst
 ) {
-const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
+const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
 // dst: (b_size, l_out, c_in, l_k)
 // src: (b_size, c_in, l_in)
-if (thread_i >= numel) {
+if (dst_i >= dst_numel) {
 return;
 }
 const size_t *src_dims = info;
@@ -74,26 +74,26 @@ __device__ void im2col1d(
 const size_t c_in = src_dims[1];
 const size_t l_in = src_dims[2];

-const size_t dst_s1 = c_in;
+const size_t dst_s2 = l_k;
+const size_t dst_s1 = c_in * dst_s2;
 const size_t dst_s0 = l_out * dst_s1;

-size_t tmp_dst_i = thread_i;
+size_t tmp_dst_i = dst_i;
 const size_t b_idx = tmp_dst_i / dst_s0;
 tmp_dst_i -= b_idx * dst_s0;
 const size_t l_idx = tmp_dst_i / dst_s1;
 tmp_dst_i -= l_idx * dst_s1;
-const size_t c_idx = tmp_dst_i;
-for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
-size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
-size_t dst_i = thread_i * l_k + l_k_idx;
+const size_t c_idx = tmp_dst_i / dst_s2;
+tmp_dst_i -= c_idx * dst_s2;
+const size_t l_k_idx = tmp_dst_i;
+size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
 if (src_l_idx < padding || src_l_idx >= l_in + padding) {
 dst[dst_i] = static_cast<T>(0);
 }
 else {
 src_l_idx -= padding;
 const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
 dst[dst_i] = src[src_i];
-}
 }
 }

@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.9.0-alpha.3"
+version = "0.9.0-alpha.1"
 edition = "2021"

 description = "Metal kernels for Candle"
@@ -33,7 +33,6 @@ criterion = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
-cudnn = ["candle/cudnn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
 metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]

@@ -1,6 +1,6 @@
 //! Convolution Layers.
 use crate::BatchNorm;
-use candle::{conv::CudnnFwdAlgo, Result, Tensor};
+use candle::{Result, Tensor};

 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 pub struct Conv1dConfig {
@@ -8,7 +8,6 @@ pub struct Conv1dConfig {
 pub stride: usize,
 pub dilation: usize,
 pub groups: usize,
-pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }

 impl Default for Conv1dConfig {
@@ -18,7 +17,6 @@ impl Default for Conv1dConfig {
 stride: 1,
 dilation: 1,
 groups: 1,
-cudnn_fwd_algo: None,
 }
 }
 }
@@ -54,13 +52,12 @@ impl Conv1d {

 impl crate::Module for Conv1d {
 fn forward(&self, x: &Tensor) -> Result<Tensor> {
-let x = x.conv1d_with_algo(
+let x = x.conv1d(
 &self.weight,
 self.config.padding,
 self.config.stride,
 self.config.dilation,
 self.config.groups,
-self.config.cudnn_fwd_algo,
 )?;
 match &self.bias {
 None => Ok(x),
@@ -150,7 +147,6 @@ pub struct Conv2dConfig {
 pub stride: usize,
 pub dilation: usize,
 pub groups: usize,
-pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
 }

 impl Default for Conv2dConfig {
@@ -160,7 +156,6 @@ impl Default for Conv2dConfig {
 stride: 1,
 dilation: 1,
 groups: 1,
-cudnn_fwd_algo: None,
 }
 }
 }
@@ -216,13 +211,12 @@ impl Conv2d {

 impl crate::Module for Conv2d {
 fn forward(&self, x: &Tensor) -> Result<Tensor> {
-let x = x.conv2d_with_algo(
+let x = x.conv2d(
 &self.weight,
 self.config.padding,
 self.config.stride,
 self.config.dilation,
 self.config.groups,
-self.config.cudnn_fwd_algo,
 )?;
 match &self.bias {
 None => Ok(x),
@ -31,7 +31,6 @@ pub mod ops;
|
|||||||
pub mod optim;
|
pub mod optim;
|
||||||
pub mod rnn;
|
pub mod rnn;
|
||||||
pub mod rotary_emb;
|
pub mod rotary_emb;
|
||||||
pub mod sampling;
|
|
||||||
pub mod sequential;
|
pub mod sequential;
|
||||||
pub mod var_builder;
|
pub mod var_builder;
|
||||||
pub mod var_map;
|
pub mod var_map;
|
||||||
|
@ -41,36 +41,12 @@ impl Linear {
|
|||||||
|
|
||||||
impl super::Module for Linear {
|
impl super::Module for Linear {
|
||||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
// When possible, we avoid using a broadcasted matmul as it is much slower
|
let w = match *x.dims() {
|
||||||
// than the standard matmul for the cuda and cpu backends.
|
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
|
||||||
let x = match *x.dims() {
|
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||||
[b1, b2, m, k] => {
|
_ => self.weight.t()?,
|
||||||
if x.is_contiguous() {
|
|
||||||
let w = self.weight.t()?;
|
|
||||||
x.reshape((b1 * b2 * m, k))?
|
|
||||||
.matmul(&w)?
|
|
||||||
.reshape((b1, b2, m, ()))?
|
|
||||||
} else {
|
|
||||||
let w = self.weight.broadcast_left((b1, b2))?.t()?;
|
|
||||||
x.matmul(&w)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
[bsize, m, k] => {
|
|
||||||
if x.is_contiguous() {
|
|
||||||
let w = self.weight.t()?;
|
|
||||||
x.reshape((bsize * m, k))?
|
|
||||||
.matmul(&w)?
|
|
||||||
.reshape((bsize, m, ()))?
|
|
||||||
} else {
|
|
||||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
|
||||||
x.matmul(&w)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
let w = self.weight.t()?;
|
|
||||||
x.matmul(&w)?
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
let x = x.matmul(&w)?;
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
Some(bias) => x.broadcast_add(bias),
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
@ -112,7 +112,7 @@ impl candle::CustomOp1 for Sigmoid {
|
|||||||
let src = &src.slice(layout.start_offset()..);
|
let src = &src.slice(layout.start_offset()..);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let out = unsafe { dev.alloc::<T>(el_count)? };
|
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||||
|
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
candle::builder_arg!(builder, el_count, dims.len());
|
candle::builder_arg!(builder, el_count, dims.len());
|
||||||
@ -373,7 +373,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(&dst);
|
builder.arg(&dst);
|
||||||
@ -561,7 +561,7 @@ impl candle::CustomOp2 for RmsNorm {
|
|||||||
};
|
};
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(&dst);
|
builder.arg(&dst);
|
||||||
@ -800,7 +800,7 @@ impl candle::CustomOp3 for LayerNorm {
|
|||||||
let func =
|
let func =
|
||||||
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(&dst);
|
builder.arg(&dst);
|
||||||
|
@ -119,7 +119,7 @@ impl candle::CustomOp3 for RotaryEmbI {
|
|||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(&cos);
|
builder.arg(&cos);
|
||||||
@ -369,7 +369,7 @@ impl candle::CustomOp3 for RotaryEmb {
|
|||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(&cos);
|
builder.arg(&cos);
|
||||||
@ -620,7 +620,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
|
|||||||
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
|
||||||
// SAFETY: Set later by running the kernel.
|
// SAFETY: Set later by running the kernel.
|
||||||
let dst = unsafe { dev.alloc::<T>(el)? };
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(&cos);
|
builder.arg(&cos);
|
||||||
|
@ -1,20 +0,0 @@
|
|||||||
use candle::{Result, Tensor};
|
|
||||||
|
|
||||||
/// Sample according to the Gumbel-Softmax distribution.
|
|
||||||
pub fn gumbel_softmax<D: candle::shape::Dim>(
|
|
||||||
logits: &Tensor,
|
|
||||||
temperature: f64,
|
|
||||||
dim: D,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
if temperature <= 0.0 {
|
|
||||||
logits.argmax(dim)
|
|
||||||
} else if temperature == 1.0 {
|
|
||||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
|
||||||
let sampled = (logits - minus_g)?.argmax(dim)?;
|
|
||||||
Ok(sampled)
|
|
||||||
} else {
|
|
||||||
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
|
|
||||||
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
|
|
||||||
Ok(sampled)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.9.0-alpha.3"
|
version = "0.9.0-alpha.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.3" }
|
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -29,7 +29,6 @@ tracing = { workspace = true }
|
|||||||
default = []
|
default = []
|
||||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
||||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||||
cudnn = ["candle/cudnn", "candle-nn/cudnn"]
|
|
||||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
||||||
metal = ["candle/metal", "candle-nn/metal"]
|
metal = ["candle/metal", "candle-nn/metal"]
|
||||||
|
@ -13,8 +13,6 @@ pub enum Sampling {
|
|||||||
TopK { k: usize, temperature: f64 },
|
TopK { k: usize, temperature: f64 },
|
||||||
TopP { p: f64, temperature: f64 },
|
TopP { p: f64, temperature: f64 },
|
||||||
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||||
// Note that the rng is not used for the Gumbel-Softmax sampling.
|
|
||||||
GumbelSoftmax { temperature: f64 },
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitsProcessor {
|
pub struct LogitsProcessor {
|
||||||
@ -51,11 +49,6 @@ impl LogitsProcessor {
|
|||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
|
|
||||||
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
|
|
||||||
sampled.to_vec0::<u32>()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
|
||||||
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
|
||||||
let next_token = distr.sample(&mut self.rng) as u32;
|
let next_token = distr.sample(&mut self.rng) as u32;
|
||||||
@ -134,9 +127,6 @@ impl LogitsProcessor {
|
|||||||
|
|
||||||
let next_token = match &self.sampling {
|
let next_token = match &self.sampling {
|
||||||
Sampling::ArgMax => self.sample_argmax(logits)?,
|
Sampling::ArgMax => self.sample_argmax(logits)?,
|
||||||
Sampling::GumbelSoftmax { temperature } => {
|
|
||||||
self.sample_gumbel_softmax(&logits, *temperature)?
|
|
||||||
}
|
|
||||||
Sampling::All { temperature } => {
|
Sampling::All { temperature } => {
|
||||||
let prs = prs(*temperature)?;
|
let prs = prs(*temperature)?;
|
||||||
self.sample_multinomial(&prs)?
|
self.sample_multinomial(&prs)?
|
||||||
|
@ -504,9 +504,8 @@ impl BertModel {
|
|||||||
Some(attention_mask) => attention_mask.clone(),
|
Some(attention_mask) => attention_mask.clone(),
|
||||||
None => input_ids.ones_like()?,
|
None => input_ids.ones_like()?,
|
||||||
};
|
};
|
||||||
let dtype = embedding_output.dtype();
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||||
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
||||||
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||||
Ok(sequence_output)
|
Ok(sequence_output)
|
||||||
}
|
}
|
||||||
@ -520,11 +519,8 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
};
|
};
|
||||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||||
// torch.finfo(dtype).min
|
// torch.finfo(dtype).min
|
||||||
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
(attention_mask.ones_like()? - &attention_mask)?
|
||||||
&Tensor::try_from(f32::MIN)?
|
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||||
.to_device(attention_mask.device())?
|
|
||||||
.to_dtype(dtype)?,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
||||||
|
@ -514,9 +514,8 @@ impl ChineseClipTextTransformer {
|
|||||||
Some(attention_mask) => attention_mask.clone(),
|
Some(attention_mask) => attention_mask.clone(),
|
||||||
None => input_ids.ones_like()?,
|
None => input_ids.ones_like()?,
|
||||||
};
|
};
|
||||||
let dtype = embedding_output.dtype();
|
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||||
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
||||||
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||||
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
||||||
let pooled_output = match &self.pooler {
|
let pooled_output = match &self.pooler {
|
||||||
@ -536,9 +535,6 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
};
|
};
|
||||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||||
// torch.finfo(dtype).min
|
// torch.finfo(dtype).min
|
||||||
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
(attention_mask.ones_like()? - &attention_mask)?
|
||||||
&Tensor::try_from(f32::MIN)?
|
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||||
.to_device(attention_mask.device())?
|
|
||||||
.to_dtype(dtype)?,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
@ -498,36 +498,4 @@ impl Model {
|
|||||||
}
|
}
|
||||||
Ok(all_samples)
|
Ok(all_samples)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
|
|
||||||
let cb = self.config.audio_num_codebooks;
|
|
||||||
let device = &self.backbone.device;
|
|
||||||
let mut mask = vec![1u8; cb];
|
|
||||||
mask.push(0);
|
|
||||||
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
|
|
||||||
|
|
||||||
frame.push(0);
|
|
||||||
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
|
|
||||||
Ok((tokens, mask))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
|
|
||||||
let cb = self.config.audio_num_codebooks;
|
|
||||||
let device = &self.backbone.device;
|
|
||||||
let mut tokens = vec![];
|
|
||||||
let mut mask = vec![];
|
|
||||||
for &v in ids.iter() {
|
|
||||||
let mut token = vec![0; cb];
|
|
||||||
token.push(v);
|
|
||||||
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
|
|
||||||
tokens.push(token);
|
|
||||||
let mut m = vec![0u8; cb];
|
|
||||||
m.push(1);
|
|
||||||
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
|
|
||||||
mask.push(m);
|
|
||||||
}
|
|
||||||
let tokens = Tensor::cat(&tokens, 1)?;
|
|
||||||
let mask = Tensor::cat(&mask, 1)?;
|
|
||||||
Ok((tokens, mask))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ impl EncoderBlock {
|
|||||||
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
||||||
let cfg1 = Conv1dConfig {
|
let cfg1 = Conv1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: stride.div_ceil(2),
|
padding: (stride + 1) / 2,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||||
@ -196,7 +196,7 @@ impl DecoderBlock {
|
|||||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||||
let cfg = ConvTranspose1dConfig {
|
let cfg = ConvTranspose1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: stride.div_ceil(2),
|
padding: (stride + 1) / 2,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
||||||
@ -330,7 +330,6 @@ impl ResidualVectorQuantizer {
|
|||||||
Ok(Self { quantizers })
|
Ok(Self { quantizers })
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::wrong_self_convention)]
|
|
||||||
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
let mut sum = None;
|
let mut sum = None;
|
||||||
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
||||||
|
@ -124,7 +124,6 @@ impl ResidualConvUnit {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv1 = conv2d(
|
let conv1 = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -209,7 +208,6 @@ impl FeatureFusionBlock {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let output_conv = conv2d(
|
let output_conv = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -260,7 +258,6 @@ impl Scratch {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let layer1_rn = conv2d_no_bias(
|
let layer1_rn = conv2d_no_bias(
|
||||||
@ -322,7 +319,6 @@ impl Scratch {
|
|||||||
stride: 1,
|
stride: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let output_conv1 = conv2d(
|
let output_conv1 = conv2d(
|
||||||
conf.num_features,
|
conf.num_features,
|
||||||
@ -429,7 +425,6 @@ impl DPTHead {
|
|||||||
stride: 2,
|
stride: 2,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
},
|
},
|
||||||
vb.pp("resize_layers").pp("3"),
|
vb.pp("resize_layers").pp("3"),
|
||||||
)?),
|
)?),
|
||||||
|
@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum HiddenAct {
|
enum HiddenAct {
|
||||||
Gelu,
|
Gelu,
|
||||||
Relu,
|
Relu,
|
||||||
}
|
}
|
||||||
@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum PositionEmbeddingType {
|
enum PositionEmbeddingType {
|
||||||
#[default]
|
#[default]
|
||||||
Absolute,
|
Absolute,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub vocab_size: usize,
|
vocab_size: usize,
|
||||||
pub dim: usize,
|
dim: usize,
|
||||||
n_layers: usize,
|
n_layers: usize,
|
||||||
n_heads: usize,
|
n_heads: usize,
|
||||||
hidden_dim: usize,
|
hidden_dim: usize,
|
||||||
activation: HiddenAct,
|
activation: HiddenAct,
|
||||||
max_position_embeddings: usize,
|
max_position_embeddings: usize,
|
||||||
initializer_range: f64,
|
initializer_range: f64,
|
||||||
pub pad_token_id: usize,
|
pad_token_id: usize,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
position_embedding_type: PositionEmbeddingType,
|
position_embedding_type: PositionEmbeddingType,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@ -345,107 +345,3 @@ impl DistilBertModel {
|
|||||||
Ok(sequence_output)
|
Ok(sequence_output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct DistilBertPredictionHeadTransform {
|
|
||||||
dense: Linear,
|
|
||||||
activation: HiddenActLayer,
|
|
||||||
layer_norm: LayerNorm,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DistilBertPredictionHeadTransform {
|
|
||||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
|
|
||||||
let activation = HiddenActLayer::new(config.activation);
|
|
||||||
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
|
|
||||||
Ok(Self {
|
|
||||||
dense,
|
|
||||||
activation,
|
|
||||||
layer_norm,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for DistilBertPredictionHeadTransform {
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
let hidden_states = self
|
|
||||||
.activation
|
|
||||||
.forward(&self.dense.forward(hidden_states)?)?;
|
|
||||||
self.layer_norm.forward(&hidden_states)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
|
|
||||||
pub struct DistilBertLMPredictionHead {
|
|
||||||
transform: DistilBertPredictionHeadTransform,
|
|
||||||
decoder: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DistilBertLMPredictionHead {
|
|
||||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
|
|
||||||
|
|
||||||
// distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a seperate vocab_projector bias
|
|
||||||
let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
|
|
||||||
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
|
|
||||||
let ws = vocab_projector_weight_vb.get_with_hints(
|
|
||||||
(config.vocab_size, config.dim),
|
|
||||||
"weight",
|
|
||||||
init_ws,
|
|
||||||
)?;
|
|
||||||
let bound = 1. / (config.dim as f64).sqrt();
|
|
||||||
let init_bs = candle_nn::Init::Uniform {
|
|
||||||
lo: -bound,
|
|
||||||
up: bound,
|
|
||||||
};
|
|
||||||
|
|
||||||
let vocab_projector_bias_vb = vb.pp("vocab_projector");
|
|
||||||
let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
|
|
||||||
|
|
||||||
let decoder = Linear::from_weights(ws, Some(bs));
|
|
||||||
|
|
||||||
Ok(Self { transform, decoder })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for DistilBertLMPredictionHead {
|
|
||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
|
||||||
self.decoder
|
|
||||||
.forward(&self.transform.forward(hidden_states)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
|
|
||||||
pub struct DistilBertOnlyMLMHead {
|
|
||||||
predictions: DistilBertLMPredictionHead,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DistilBertOnlyMLMHead {
|
|
||||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
|
|
||||||
Ok(Self { predictions })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for DistilBertOnlyMLMHead {
|
|
||||||
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
|
|
||||||
self.predictions.forward(sequence_output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct DistilBertForMaskedLM {
|
|
||||||
pub bert: DistilBertModel,
|
|
||||||
cls: DistilBertOnlyMLMHead,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DistilBertForMaskedLM {
|
|
||||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
|
||||||
let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
|
|
||||||
let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
|
|
||||||
Ok(Self { bert, cls })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
|
||||||
let sequence_output = self.bert.forward(input_ids, attention_mask)?;
|
|
||||||
self.cls.forward(&sequence_output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -141,20 +141,6 @@ pub fn conv1d_weight_norm(
|
|||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn conv1d_weight_norm_no_bias(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
config: candle_nn::Conv1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
|
||||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
|
||||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
|
||||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
|
||||||
Ok(Conv1d::new(weight, None, config))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn conv_transpose1d_weight_norm(
|
pub fn conv_transpose1d_weight_norm(
|
||||||
in_c: usize,
|
in_c: usize,
|
||||||
out_c: usize,
|
out_c: usize,
|
||||||
@ -468,7 +454,6 @@ impl EncodecConv1d {
|
|||||||
stride,
|
stride,
|
||||||
groups: 1,
|
groups: 1,
|
||||||
dilation: 1,
|
dilation: 1,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
},
|
},
|
||||||
vb.pp("conv"),
|
vb.pp("conv"),
|
||||||
)?,
|
)?,
|
||||||
|
@ -6,8 +6,8 @@ pub fn get_noise(
|
|||||||
width: usize,
|
width: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let height = height.div_ceil(16) * 2;
|
let height = (height + 15) / 16 * 2;
|
||||||
let width = width.div_ceil(16) * 2;
|
let width = (width + 15) / 16 * 2;
|
||||||
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
|
|||||||
|
|
||||||
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||||
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
||||||
let height = height.div_ceil(16);
|
let height = (height + 15) / 16;
|
||||||
let width = width.div_ceil(16);
|
let width = (width + 15) / 16;
|
||||||
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
||||||
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
||||||
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
||||||
|
@ -27,7 +27,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
self.d_model.div_ceil(16)
|
(self.d_model + 15) / 16
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_inner(&self) -> usize {
|
fn d_inner(&self) -> usize {
|
||||||
|
@ -716,7 +716,7 @@ pub mod transformer {
|
|||||||
None => {
|
None => {
|
||||||
let hidden_dim = self.dim * 4;
|
let hidden_dim = self.dim * 4;
|
||||||
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
||||||
n_hidden.div_ceil(256) * 256
|
(n_hidden + 255) / 256 * 256
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -267,7 +267,6 @@ impl StreamableConv1d {
|
|||||||
stride,
|
stride,
|
||||||
dilation,
|
dilation,
|
||||||
groups,
|
groups,
|
||||||
cudnn_fwd_algo: None,
|
|
||||||
};
|
};
|
||||||
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
|
||||||
if k_size < stride {
|
if k_size < stride {
|
||||||
|
@ -104,7 +104,6 @@ pub mod rwkv_v6;
|
|||||||
pub mod segformer;
|
pub mod segformer;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
pub mod siglip;
|
pub mod siglip;
|
||||||
pub mod snac;
|
|
||||||
pub mod stable_diffusion;
|
pub mod stable_diffusion;
|
||||||
pub mod stable_lm;
|
pub mod stable_lm;
|
||||||
pub mod starcoder2;
|
pub mod starcoder2;
|
||||||
|
@ -1,814 +0,0 @@
|
|||||||
#![allow(unused)]
|
|
||||||
//! Implementation of the Multi-Scale Neural Audio Codec (SNAC)
|
|
||||||
//!
|
|
||||||
//! See: [SNAC](https://github.com/hubertsiuzdak/snac)
|
|
||||||
//!
|
|
||||||
/// Multi-Scale Neural Audio Codec (SNAC) compresses audio into discrete codes at a low bitrate.
|
|
||||||
/// For more information, read the paper: https://arxiv.org/abs/2410.14411
|
|
||||||
///
|
|
||||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
|
||||||
use candle_nn::{
|
|
||||||
linear_b, Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, LayerNorm, Linear,
|
|
||||||
VarBuilder,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(serde::Deserialize, Debug, Clone)]
|
|
||||||
pub struct Config {
|
|
||||||
pub sampling_rate: usize,
|
|
||||||
pub encoder_dim: usize,
|
|
||||||
pub encoder_rates: Vec<usize>,
|
|
||||||
pub decoder_dim: usize,
|
|
||||||
pub decoder_rates: Vec<usize>,
|
|
||||||
pub attn_window_size: Option<usize>,
|
|
||||||
pub codebook_size: usize,
|
|
||||||
pub codebook_dim: usize,
|
|
||||||
pub vq_strides: Vec<usize>,
|
|
||||||
pub noise: bool,
|
|
||||||
pub depthwise: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Equivalent to torch.repeat_interleave
|
|
||||||
pub fn repeat_interleave<D: candle::shape::Dim>(
|
|
||||||
img: &Tensor,
|
|
||||||
repeats: usize,
|
|
||||||
dim: D,
|
|
||||||
) -> Result<Tensor> {
|
|
||||||
if repeats == 1 {
|
|
||||||
return Ok(img.clone());
|
|
||||||
}
|
|
||||||
let dim = dim.to_index(img.shape(), "chunk")?;
|
|
||||||
let img = img.unsqueeze(dim + 1)?;
|
|
||||||
let mut dims = img.dims().to_vec();
|
|
||||||
dims[dim + 1] = repeats;
|
|
||||||
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn conv1d_weight_norm(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
config: candle_nn::Conv1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
|
||||||
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
|
|
||||||
let weight_v = {
|
|
||||||
let name = "parametrizations.weight.original1";
|
|
||||||
match vb.get((out_c, in_c, kernel_size), name) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
|
||||||
let bias = vb.get(out_c, "bias")?;
|
|
||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn conv1d_weight_norm_no_bias(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
config: candle_nn::Conv1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
|
||||||
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
|
|
||||||
let weight_v = {
|
|
||||||
let name = "parametrizations.weight.original1";
|
|
||||||
match vb.get((out_c, in_c, kernel_size), name) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
|
||||||
Ok(Conv1d::new(weight, None, config))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn conv_transpose1d_weight_norm(
|
|
||||||
in_c: usize,
|
|
||||||
out_c: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
bias: bool,
|
|
||||||
config: candle_nn::ConvTranspose1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<ConvTranspose1d> {
|
|
||||||
let weight_g = vb.get((in_c, 1, 1), "parametrizations.weight.original0")?;
|
|
||||||
let weight_v = vb.get(
|
|
||||||
(in_c, out_c, kernel_size),
|
|
||||||
"parametrizations.weight.original1",
|
|
||||||
)?;
|
|
||||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
|
||||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
|
||||||
let bias = if bias {
|
|
||||||
Some(vb.get(out_c, "bias")?)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(ConvTranspose1d::new(weight, bias, config))
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/hubertsiuzdak/snac/blob/main/snac/attention.py
|
|
||||||
#[allow(unused)]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct SinusoidalEmbeddings {
|
|
||||||
inv_freq: Tensor,
|
|
||||||
scale: Tensor,
|
|
||||||
scale_base: f32,
|
|
||||||
use_xpos: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SinusoidalEmbeddings {
|
|
||||||
fn new(dim: usize, scale_base: f32, use_xpos: bool, dev: &Device) -> Result<Self> {
|
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
|
||||||
.step_by(2)
|
|
||||||
.map(|i| 1f32 / 10_000f32.powf(i as f32 / dim as f32))
|
|
||||||
.collect();
|
|
||||||
let len = inv_freq.len();
|
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, len, dev)?.to_dtype(DType::F32)?;
|
|
||||||
let scale: Vec<_> = (0..dim)
|
|
||||||
.step_by(2)
|
|
||||||
.map(|i| (i as f32 + 0.4 * dim as f32) / (1.4 * dim as f32))
|
|
||||||
.collect();
|
|
||||||
let scale = Tensor::from_vec(scale, len, dev)?.to_dtype(DType::F32)?;
|
|
||||||
Ok(Self {
|
|
||||||
inv_freq,
|
|
||||||
scale,
|
|
||||||
scale_base,
|
|
||||||
use_xpos,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct LocalMHA {
|
|
||||||
norm: LayerNorm,
|
|
||||||
to_qkv: Linear,
|
|
||||||
to_out: Linear,
|
|
||||||
num_heads: usize,
|
|
||||||
head_dim: usize,
|
|
||||||
rel_pos: Option<SinusoidalEmbeddings>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LocalMHA {
|
|
||||||
fn new(
|
|
||||||
dim: usize,
|
|
||||||
window_size: usize,
|
|
||||||
dim_head: usize,
|
|
||||||
use_rotary_pos_emb: bool,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
|
|
||||||
let to_qkv = linear_b(dim, dim * 3, false, vb.pp("to_qkv"))?;
|
|
||||||
let to_out = linear_b(dim, dim, false, vb.pp("to_out"))?;
|
|
||||||
let rel_pos = if use_rotary_pos_emb {
|
|
||||||
let rel_pos =
|
|
||||||
SinusoidalEmbeddings::new(dim_head, window_size as f32 / 2.0, false, vb.device())?;
|
|
||||||
Some(rel_pos)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
norm,
|
|
||||||
to_qkv,
|
|
||||||
to_out,
|
|
||||||
rel_pos,
|
|
||||||
num_heads: dim / dim_head,
|
|
||||||
head_dim: dim_head,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for LocalMHA {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let (b, c, t) = xs.dims3()?;
|
|
||||||
let residual = xs.clone();
|
|
||||||
let xs = xs.transpose(1, 2)?.apply(&self.norm)?;
|
|
||||||
let qkv = xs.apply(&self.to_qkv)?;
|
|
||||||
let q = qkv.narrow(D::Minus1, 0, c)?;
|
|
||||||
let k = qkv.narrow(D::Minus1, c, c)?;
|
|
||||||
let v = qkv.narrow(D::Minus1, 2 * c, c)?;
|
|
||||||
let q = q
|
|
||||||
.reshape((b, t, self.num_heads, self.head_dim))?
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.contiguous()?;
|
|
||||||
let k = k
|
|
||||||
.reshape((b, t, self.num_heads, self.head_dim))?
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.contiguous()?;
|
|
||||||
let v = v
|
|
||||||
.reshape((b, t, self.num_heads, self.head_dim))?
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.contiguous()?;
|
|
||||||
let (q, k) = match self.rel_pos {
|
|
||||||
Some(_) => todo!(),
|
|
||||||
None => (q, k),
|
|
||||||
};
|
|
||||||
let out = {
|
|
||||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
|
||||||
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
|
||||||
// Non-causal attention
|
|
||||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
|
||||||
attn_weights.matmul(&v)?
|
|
||||||
};
|
|
||||||
let out = out
|
|
||||||
.transpose(1, 2)?
|
|
||||||
.reshape((b, t, self.num_heads * self.head_dim))?
|
|
||||||
.apply(&self.to_out)?;
|
|
||||||
out.transpose(1, 2)? + residual
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct Snake1d {
|
|
||||||
alpha: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Snake1d {
|
|
||||||
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let alpha = vb.get((1, channels, 1), "alpha")?;
|
|
||||||
Ok(Self { alpha })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for Snake1d {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let xs_shape = xs.shape();
|
|
||||||
let xs = xs.flatten_from(2)?;
|
|
||||||
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
|
|
||||||
let sin = (&sin * &sin)?;
|
|
||||||
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct ResidualUnit {
|
|
||||||
snake1: Snake1d,
|
|
||||||
conv1: Conv1d,
|
|
||||||
snake2: Snake1d,
|
|
||||||
conv2: Conv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResidualUnit {
|
|
||||||
fn new(
|
|
||||||
dim: usize,
|
|
||||||
dilation: usize,
|
|
||||||
kernel: usize,
|
|
||||||
groups: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let pad = ((kernel - 1) * dilation) / 2;
|
|
||||||
let vb = vb.pp("block");
|
|
||||||
let snake1 = Snake1d::new(dim, vb.pp(0))?;
|
|
||||||
let cfg1 = Conv1dConfig {
|
|
||||||
dilation,
|
|
||||||
padding: pad,
|
|
||||||
groups,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv1 = conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
|
|
||||||
let snake2 = Snake1d::new(dim, vb.pp(2))?;
|
|
||||||
let conv2 = conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
|
|
||||||
Ok(Self {
|
|
||||||
snake1,
|
|
||||||
conv1,
|
|
||||||
snake2,
|
|
||||||
conv2,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for ResidualUnit {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let ys = xs
|
|
||||||
.apply(&self.snake1)?
|
|
||||||
.apply(&self.conv1)?
|
|
||||||
.apply(&self.snake2)?
|
|
||||||
.apply(&self.conv2)?;
|
|
||||||
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
|
|
||||||
if pad > 0 {
|
|
||||||
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
|
|
||||||
} else {
|
|
||||||
ys + xs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct NoiseBlock {
|
|
||||||
linear: Conv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl NoiseBlock {
|
|
||||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
|
||||||
let linear = conv1d_weight_norm_no_bias(dim, dim, 1, Default::default(), vb.pp("linear"))?;
|
|
||||||
Ok(Self { linear })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for NoiseBlock {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let (b, _c, t) = xs.dims3()?;
|
|
||||||
let noise = Tensor::randn(0f32, 1f32, (b, 1, t), xs.device())?;
|
|
||||||
let h = xs.apply(&self.linear)?;
|
|
||||||
let n = noise.broadcast_mul(&h)?;
|
|
||||||
let xs = (xs + n)?;
|
|
||||||
Ok(xs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct DecoderBlock {
|
|
||||||
snake1: Snake1d,
|
|
||||||
conv_tr1: ConvTranspose1d,
|
|
||||||
noise: Option<NoiseBlock>,
|
|
||||||
res1: ResidualUnit,
|
|
||||||
res2: ResidualUnit,
|
|
||||||
res3: ResidualUnit,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DecoderBlock {
|
|
||||||
fn new(
|
|
||||||
in_dim: usize,
|
|
||||||
out_dim: usize,
|
|
||||||
stride: usize,
|
|
||||||
noise: bool,
|
|
||||||
groups: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let vb = vb.pp("block");
|
|
||||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
|
||||||
let cfg = ConvTranspose1dConfig {
|
|
||||||
stride,
|
|
||||||
padding: stride.div_ceil(2),
|
|
||||||
output_padding: stride % 2,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv_tr1 =
|
|
||||||
conv_transpose1d_weight_norm(in_dim, out_dim, 2 * stride, true, cfg, vb.pp(1))?;
|
|
||||||
let (n, noise) = if noise {
|
|
||||||
let noise = NoiseBlock::new(out_dim, vb.pp(2))?;
|
|
||||||
(1, Some(noise))
|
|
||||||
} else {
|
|
||||||
(0, None)
|
|
||||||
};
|
|
||||||
let res1 = ResidualUnit::new(out_dim, 1, 7, groups, vb.pp(2 + n))?;
|
|
||||||
let res2 = ResidualUnit::new(out_dim, 3, 7, groups, vb.pp(3 + n))?;
|
|
||||||
let res3 = ResidualUnit::new(out_dim, 9, 7, groups, vb.pp(4 + n))?;
|
|
||||||
Ok(Self {
|
|
||||||
snake1,
|
|
||||||
conv_tr1,
|
|
||||||
noise,
|
|
||||||
res1,
|
|
||||||
res2,
|
|
||||||
res3,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Module for DecoderBlock {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
xs.apply(&self.snake1)?
|
|
||||||
.apply(&self.conv_tr1)?
|
|
||||||
.apply(&self.noise.as_ref())?
|
|
||||||
.apply(&self.res1)?
|
|
||||||
.apply(&self.res2)?
|
|
||||||
.apply(&self.res3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct EncoderBlock {
|
|
||||||
res1: ResidualUnit,
|
|
||||||
res2: ResidualUnit,
|
|
||||||
res3: ResidualUnit,
|
|
||||||
snake1: Snake1d,
|
|
||||||
conv1: Conv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EncoderBlock {
|
|
||||||
fn new(
|
|
||||||
out_dim: usize,
|
|
||||||
in_dim: Option<usize>,
|
|
||||||
stride: usize,
|
|
||||||
groups: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let vb = vb.pp("block");
|
|
||||||
let in_dim = in_dim.unwrap_or(out_dim / 2);
|
|
||||||
let res1 = ResidualUnit::new(in_dim, 1, 7, groups, vb.pp(0))?;
|
|
||||||
let res2 = ResidualUnit::new(in_dim, 3, 7, groups, vb.pp(1))?;
|
|
||||||
let res3 = ResidualUnit::new(in_dim, 9, 7, groups, vb.pp(2))?;
|
|
||||||
let snake1 = Snake1d::new(in_dim, vb.pp(3))?;
|
|
||||||
let cfg1 = Conv1dConfig {
|
|
||||||
stride,
|
|
||||||
padding: stride.div_ceil(2),
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv1 = conv1d_weight_norm(in_dim, out_dim, 2 * stride, cfg1, vb.pp(4))?;
|
|
||||||
Ok(Self {
|
|
||||||
res1,
|
|
||||||
res2,
|
|
||||||
res3,
|
|
||||||
snake1,
|
|
||||||
conv1,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl candle::Module for EncoderBlock {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
xs.apply(&self.res1)?
|
|
||||||
.apply(&self.res2)?
|
|
||||||
.apply(&self.res3)?
|
|
||||||
.apply(&self.snake1)?
|
|
||||||
.apply(&self.conv1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct Encoder {
|
|
||||||
conv1: Conv1d,
|
|
||||||
blocks: Vec<EncoderBlock>,
|
|
||||||
local_mha: Option<LocalMHA>,
|
|
||||||
conv2: Conv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl candle::Module for Encoder {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut xs = xs.apply(&self.conv1)?;
|
|
||||||
for block in self.blocks.iter() {
|
|
||||||
xs = xs.apply(block)?
|
|
||||||
}
|
|
||||||
xs.apply(&self.conv2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Encoder {
|
|
||||||
fn new(
|
|
||||||
mut d_model: usize,
|
|
||||||
strides: &[usize],
|
|
||||||
depthwise: bool,
|
|
||||||
attn_window_size: Option<usize>,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let vb = vb.pp("block");
|
|
||||||
let mut idx = 0;
|
|
||||||
let cfg1 = Conv1dConfig {
|
|
||||||
padding: 3,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv1 = conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
let mut blocks = Vec::with_capacity(strides.len());
|
|
||||||
for &stride in strides.iter() {
|
|
||||||
d_model *= 2;
|
|
||||||
let groups = if depthwise { d_model / 2 } else { 1 };
|
|
||||||
let block = EncoderBlock::new(d_model, None, stride, groups, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
blocks.push(block)
|
|
||||||
}
|
|
||||||
let local_mha = match attn_window_size {
|
|
||||||
Some(w) => {
|
|
||||||
let mha = LocalMHA::new(d_model, w, 64, true, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
Some(mha)
|
|
||||||
}
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
let groups = if depthwise { d_model } else { 1 };
|
|
||||||
let cfg2 = Conv1dConfig {
|
|
||||||
padding: 3,
|
|
||||||
groups,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv2 = conv1d_weight_norm(d_model, d_model, 7, cfg2, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
Ok(Self {
|
|
||||||
conv1,
|
|
||||||
blocks,
|
|
||||||
local_mha,
|
|
||||||
conv2,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
enum ConvInit {
|
|
||||||
Depthwise(Conv1d, Conv1d),
|
|
||||||
Standard(Conv1d),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub struct Decoder {
|
|
||||||
conv1: ConvInit,
|
|
||||||
local_mha: Option<LocalMHA>,
|
|
||||||
blocks: Vec<DecoderBlock>,
|
|
||||||
snake1: Snake1d,
|
|
||||||
conv2: Conv1d,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Decoder {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
in_c: usize,
|
|
||||||
mut channels: usize,
|
|
||||||
rates: &[usize],
|
|
||||||
noise: bool,
|
|
||||||
depthwise: bool,
|
|
||||||
attn_window_size: Option<usize>,
|
|
||||||
d_out: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let vb = vb.pp("model");
|
|
||||||
let mut idx = 0;
|
|
||||||
let pad3 = Conv1dConfig {
|
|
||||||
padding: 3,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv1 = if depthwise {
|
|
||||||
let cfg1 = Conv1dConfig {
|
|
||||||
padding: 3,
|
|
||||||
groups: in_c,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let conv1 = conv1d_weight_norm(in_c, in_c, 7, cfg1, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
let conv2 = conv1d_weight_norm(in_c, channels, 1, Default::default(), vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
ConvInit::Depthwise(conv1, conv2)
|
|
||||||
} else {
|
|
||||||
let conv1 = conv1d_weight_norm(in_c, channels, 7, pad3, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
ConvInit::Standard(conv1)
|
|
||||||
};
|
|
||||||
let mut blocks = Vec::with_capacity(rates.len());
|
|
||||||
let local_mha = match attn_window_size {
|
|
||||||
Some(w) => {
|
|
||||||
let mha = LocalMHA::new(channels, w, 64, true, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
Some(mha)
|
|
||||||
}
|
|
||||||
None => None,
|
|
||||||
};
|
|
||||||
for stride in rates.iter() {
|
|
||||||
let groups = if depthwise { channels / 2 } else { 1 };
|
|
||||||
let block =
|
|
||||||
DecoderBlock::new(channels, channels / 2, *stride, noise, groups, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
channels /= 2;
|
|
||||||
blocks.push(block)
|
|
||||||
}
|
|
||||||
let snake1 = Snake1d::new(channels, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
let conv2 = conv1d_weight_norm(channels, d_out, 7, pad3, vb.pp(idx))?;
|
|
||||||
idx += 1;
|
|
||||||
Ok(Self {
|
|
||||||
conv1,
|
|
||||||
local_mha,
|
|
||||||
blocks,
|
|
||||||
snake1,
|
|
||||||
conv2,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl candle::Module for Decoder {
|
|
||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut xs = match &self.conv1 {
|
|
||||||
ConvInit::Standard(c) => xs.apply(c)?,
|
|
||||||
ConvInit::Depthwise(c1, c2) => xs.apply(c1)?.apply(c2)?,
|
|
||||||
};
|
|
||||||
for block in self.blocks.iter() {
|
|
||||||
xs = xs.apply(block)?
|
|
||||||
}
|
|
||||||
xs.apply(&self.snake1)?.apply(&self.conv2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn normalize(v: &Tensor) -> Result<Tensor> {
|
|
||||||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/hubertsiuzdak/snac/blob/main/snac/vq.py
|
|
||||||
#[allow(unused)]
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
struct VectorQuantizer {
|
|
||||||
in_proj: Conv1d,
|
|
||||||
out_proj: Conv1d,
|
|
||||||
codebook: candle_nn::Embedding,
|
|
||||||
stride: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VectorQuantizer {
|
|
||||||
fn new(
|
|
||||||
in_dim: usize,
|
|
||||||
cb_size: usize,
|
|
||||||
cb_dim: usize,
|
|
||||||
stride: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let in_proj = conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
|
|
||||||
let out_proj =
|
|
||||||
conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
|
|
||||||
let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
|
|
||||||
Ok(Self {
|
|
||||||
in_proj,
|
|
||||||
out_proj,
|
|
||||||
codebook,
|
|
||||||
stride,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode_latents(&self, latents: &Tensor) -> Result<(Tensor, Tensor)> {
|
|
||||||
let (b, d, t) = latents.dims3()?;
|
|
||||||
let encodings = latents.transpose(1, 2)?.reshape((b * t, d))?;
|
|
||||||
let encodings = normalize(&encodings)?;
|
|
||||||
let codebook = normalize(self.codebook.embeddings())?;
|
|
||||||
let dist = (encodings
|
|
||||||
.sqr()?
|
|
||||||
.sum_keepdim(1)?
|
|
||||||
.broadcast_sub(&encodings.matmul(&codebook.t()?)?)?
|
|
||||||
* 2.0)?
|
|
||||||
.broadcast_add(&codebook.sqr()?.sum_keepdim(1)?.t()?)?;
|
|
||||||
let indices = dist.argmin(1)?.reshape((b, ()))?;
|
|
||||||
let z_q = self.decode_code(&indices)?;
|
|
||||||
Ok((z_q, indices))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn encode(&self, z: &Tensor) -> Result<(Tensor, Tensor)> {
|
|
||||||
let z = if self.stride > 1 {
|
|
||||||
let (b, c, t) = z.dims3()?;
|
|
||||||
z.reshape((b, c, 1, t))?
|
|
||||||
.avg_pool2d((1, self.stride))?
|
|
||||||
.squeeze(2)?
|
|
||||||
} else {
|
|
||||||
z.clone()
|
|
||||||
};
|
|
||||||
let z_e = z.apply(&self.in_proj)?;
|
|
||||||
let (z_q, indices) = self.decode_latents(&z_e)?;
|
|
||||||
let z_q = z_q.apply(&self.out_proj)?;
|
|
||||||
let z_q = if self.stride > 1 {
|
|
||||||
repeat_interleave(&z_q, self.stride, D::Minus1)?
|
|
||||||
} else {
|
|
||||||
z_q
|
|
||||||
};
|
|
||||||
Ok((z_q, indices))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
|
|
||||||
embed_id.apply(&self.codebook)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
|
|
||||||
self.embed_code(embed_id)?.transpose(1, 2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct ResidualVectorQuantizer {
|
|
||||||
quantizers: Vec<VectorQuantizer>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResidualVectorQuantizer {
|
|
||||||
fn new(
|
|
||||||
input_dim: usize,
|
|
||||||
cb_size: usize,
|
|
||||||
cb_dim: usize,
|
|
||||||
vq_strides: &[usize],
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let vb = &vb.pp("quantizers");
|
|
||||||
let quantizers = vq_strides
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.map(|(i, stride)| VectorQuantizer::new(input_dim, cb_size, cb_dim, *stride, vb.pp(i)))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
Ok(Self { quantizers })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn encode(&self, z: &Tensor) -> Result<(Tensor, Vec<Tensor>)> {
|
|
||||||
let mut residual = z.clone();
|
|
||||||
let mut z_q = z.zeros_like()?;
|
|
||||||
let mut codes = Vec::with_capacity(self.quantizers.len());
|
|
||||||
for quantizer in self.quantizers.iter() {
|
|
||||||
let (z_q_i, indices_i) = quantizer.encode(&residual)?;
|
|
||||||
z_q = (z_q + &z_q_i)?;
|
|
||||||
residual = (residual - &z_q_i)?;
|
|
||||||
codes.push(indices_i)
|
|
||||||
}
|
|
||||||
Ok((z_q, codes))
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::wrong_self_convention)]
|
|
||||||
fn from_codes(&self, codes: &[&Tensor]) -> Result<Tensor> {
|
|
||||||
let mut sum = None;
|
|
||||||
for (quantizer, codes) in self.quantizers.iter().zip(codes.iter()) {
|
|
||||||
let z_p_i = quantizer.decode_code(codes)?;
|
|
||||||
let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
|
|
||||||
let z_q_i = repeat_interleave(&z_q_i, quantizer.stride, D::Minus1)?;
|
|
||||||
let s = match sum {
|
|
||||||
None => z_q_i,
|
|
||||||
Some(s) => (s + z_q_i)?,
|
|
||||||
};
|
|
||||||
sum = Some(s)
|
|
||||||
}
|
|
||||||
match sum {
|
|
||||||
Some(s) => Ok(s),
|
|
||||||
None => candle::bail!("empty codebooks"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn gcd(mut a: usize, mut b: usize) -> usize {
|
|
||||||
while b != 0 {
|
|
||||||
let t = b;
|
|
||||||
b = a % b;
|
|
||||||
a = t;
|
|
||||||
}
|
|
||||||
a
|
|
||||||
}
|
|
||||||
|
|
||||||
fn lcm(a: usize, b: usize) -> usize {
|
|
||||||
a / gcd(a, b) * b
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/hubertsiuzdak/snac/blob/main/snac/snac.py
#[derive(Debug, Clone)]
pub struct Model {
    pub encoder: Encoder,
    pub quantizer: ResidualVectorQuantizer,
    pub decoder: Decoder,
    pub hop_length: usize,
    pub config: Config,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = Encoder::new(
            cfg.encoder_dim,
            &cfg.encoder_rates,
            cfg.depthwise,
            cfg.attn_window_size,
            vb.pp("encoder"),
        )?;
        let latent_dim = cfg.encoder_dim * 2usize.pow(cfg.encoder_rates.len() as u32);
        let quantizer = ResidualVectorQuantizer::new(
            latent_dim,
            cfg.codebook_size,
            cfg.codebook_dim,
            &cfg.vq_strides,
            vb.pp("quantizer"),
        )?;
        let decoder = Decoder::new(
            latent_dim,
            cfg.decoder_dim,
            &cfg.decoder_rates,
            cfg.noise,
            cfg.depthwise,
            cfg.attn_window_size,
            /* d_out */ 1,
            vb.pp("decoder"),
        )?;
        let hop_length = cfg.encoder_rates.iter().product::<usize>();
        Ok(Self {
            encoder,
            decoder,
            quantizer,
            config: cfg.clone(),
            hop_length,
        })
    }

    fn preprocess(&self, audio_data: &Tensor) -> Result<Tensor> {
        let len = audio_data.dim(D::Minus1)?;
        let lcm = lcm(
            self.config.vq_strides[0],
            self.config.attn_window_size.unwrap_or(1),
        );
        let pad_to = self.hop_length * lcm;
        let right_pad = len.div_ceil(pad_to) * pad_to - len;
        let audio_data = audio_data.pad_with_zeros(D::Minus1, 0, right_pad)?;
        Ok(audio_data)
    }

    pub fn encode(&self, audio_data: &Tensor) -> Result<Vec<Tensor>> {
        let audio_data = self.preprocess(audio_data)?;
        let z = self.encoder.forward(&audio_data)?;
        let (_, codes) = self.quantizer.encode(&z)?;
        Ok(codes)
    }

    pub fn decode(&self, audio_codes: &[&Tensor]) -> Result<Tensor> {
        let audio_values = self.quantizer.from_codes(audio_codes)?;
        audio_values.apply(&self.decoder)
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn num_codebooks(&self) -> usize {
        self.quantizer.quantizers.len()
    }
}
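A sketch of how the pieces above might be driven end to end. This is not part of the model file: the weight path, dtype, and config are placeholders, and `Config`/`Model` are the types defined in this module (within candle-transformers the candle-core crate is imported as `candle`).

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

// Encode a mono waveform to codebook indices and decode it back (sketch).
fn roundtrip(cfg: &Config, weights_path: &str, pcm: &[f32]) -> Result<Tensor> {
    let device = Device::Cpu;
    // Safety: mmaping the file assumes it is a valid safetensors checkpoint.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)? };
    let model = Model::new(cfg, vb)?;
    let audio = Tensor::from_slice(pcm, (1, 1, pcm.len()), &device)?; // (batch, channel, time)
    let codes = model.encode(&audio)?;
    let code_refs: Vec<&Tensor> = codes.iter().collect();
    model.decode(&code_refs)
}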
@@ -68,7 +68,6 @@ impl ResnetBlock2D
             padding: 1,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
         let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
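The same one-line removal of `cudnn_fwd_algo` recurs in the `AudioEncoder` and `ConvBlock` hunks further down. As an aside (not part of the diff): struct-update syntax would keep such config literals compiling whether or not the optional field exists, assuming the config type keeps its `Default` impl, as in this sketch:

fn example_cfg() -> candle_nn::Conv2dConfig {
    // Unspecified fields fall back to Conv2dConfig's Default impl
    // (stride = 1, groups = 1, dilation = 1, and cudnn_fwd_algo = None
    // on the versions that carry that field).
    candle_nn::Conv2dConfig {
        padding: 1,
        ..Default::default()
    }
}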
@@ -84,7 +83,6 @@ impl ResnetBlock2D
             padding: 0,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         Some(conv2d(
             in_channels,
@@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_<T: Float>(
     let samples = {
         let mut samples_padded = samples.to_vec();
         let to_add = n_len * fft_step - samples.len();
-        samples_padded.extend(std::iter::repeat_n(zero, to_add));
+        samples_padded.extend(std::iter::repeat(zero).take(to_add));
         samples_padded
     };

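The only change in this hunk swaps the padding iterator: `std::iter::repeat_n(zero, to_add)` and `std::iter::repeat(zero).take(to_add)` yield exactly the same sequence of `to_add` copies of `zero`; `repeat_n` is simply the newer standard-library API (stabilized in Rust 1.82), so the `.take` form builds on older toolchains. A quick equivalence check (illustrative, not part of the diff):

fn main() {
    let (zero, to_add) = (0.0_f32, 3);
    let a: Vec<f32> = std::iter::repeat_n(zero, to_add).collect();
    let b: Vec<f32> = std::iter::repeat(zero).take(to_add).collect();
    assert_eq!(a, b); // both are [0.0, 0.0, 0.0]
}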
@@ -248,14 +248,12 @@ impl AudioEncoder
             stride: 1,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
@@ -244,14 +244,12 @@ impl AudioEncoder
             stride: 1,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         let cfg2 = Conv1dConfig {
             padding: 1,
             stride: 2,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
         let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
@@ -54,25 +54,3 @@ fn sample_with_top_k() -> Result<()>
     assert_eq!(token, 2);
     Ok(())
 }
-
-#[test]
-fn sample_gumbel() -> Result<()> {
-    let mut logits_process = LogitsProcessor::from_sampling(
-        42,
-        candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
-    );
-    let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
-    let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
-    let mut counts = vec![0f64; 4];
-    let samples = 100000;
-    for _ in 0..samples {
-        let token = logits_process.sample(&logits)?;
-        counts[token as usize] += 1f64 / samples as f64;
-    }
-    for i in 0..4 {
-        if (counts[i] - sm[i]).abs() > 0.05 {
-            panic!("pr mismatch {counts:?} {sm:?}");
-        }
-    }
-    Ok(())
-}
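For context, the test removed here drew 100,000 tokens with Gumbel-softmax sampling at temperature 1.0 and checked that the empirical frequencies stay within 0.05 of the softmax probabilities. That property rests on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to the logits and taking the argmax draws exactly from softmax(logits). A minimal sketch of that trick (illustrative, not the LogitsProcessor implementation):

// One draw via the Gumbel-max trick: argmax_i (logit_i + g_i) with
// g_i = -ln(-ln(u_i)) and u_i ~ Uniform(0, 1).
fn gumbel_argmax(logits: &[f64], uniforms: &[f64]) -> usize {
    logits
        .iter()
        .zip(uniforms)
        .map(|(l, u)| l - (-u.ln()).ln())
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}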
@@ -177,7 +177,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
     let samples = {
         let mut samples_padded = samples.to_vec();
         let to_add = n_len * fft_step - samples.len();
-        samples_padded.extend(std::iter::repeat_n(zero, to_add));
+        samples_padded.extend(std::iter::repeat(zero).take(to_add));
         samples_padded
     };

@@ -98,7 +98,6 @@ impl ConvBlock
             stride,
             groups: 1,
             dilation: 1,
-            cudnn_fwd_algo: None,
         };
         let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
         let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;