Compare commits


5 Commits

SHA1 Message Date
543b5b5898 Update for the latest cudarc. 2025-04-11 14:02:41 +02:00
c87f0fa5d6 Merge remote-tracking branch 'origin/main' into cuda-graph-exp 2025-04-11 13:47:35 +02:00
1bb68854d3 Tweaks to the graph experiment. 2024-10-03 17:12:52 +02:00
b2956857ef More cuda graph attempts. 2024-10-03 12:43:08 +02:00
9076dee432 Cuda graph experiments. 2024-10-03 08:43:00 +02:00
43 changed files with 396 additions and 956 deletions

40
.github/workflows/book-cd.yml vendored Normal file

@ -0,0 +1,40 @@
name: Deploy Rust book
on:
push:
branches:
- main
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir mdbook
curl -sSL $url | tar -xz --directory=./mdbook
echo `pwd`/mdbook >> $GITHUB_PATH
- name: Deploy GitHub Pages
run: |
# This assumes your book is in the root of your repository.
# Just add a `cd` here if you need to change to another directory.
cd candle-book
mdbook build
git worktree add gh-pages
git config user.name "Deploy from CI"
git config user.email ""
cd gh-pages
# Delete the ref to avoid keeping history.
git update-ref -d refs/heads/gh-pages
rm -rf *
mv ../book/* .
git add .
git commit -m "Deploy $GITHUB_SHA to gh-pages"
git push --force --set-upstream origin gh-pages

29
.github/workflows/book.yml vendored Normal file

@ -0,0 +1,29 @@
name: CI
on:
pull_request:
jobs:
test:
name: Test candle-book
runs-on: ubuntu-latest
permissions:
contents: write # To push a branch
pull-requests: write # To create a PR from that branch
steps:
- uses: actions/checkout@master
- name: Install Rust
run: |
rustup set profile minimal
rustup toolchain install stable
rustup default stable
- name: Install latest mdbook
run: |
tag=$(curl 'https://api.github.com/repos/rust-lang/mdbook/releases/latest' | jq -r '.tag_name')
url="https://github.com/rust-lang/mdbook/releases/download/${tag}/mdbook-${tag}-x86_64-unknown-linux-gnu.tar.gz"
mkdir bin
curl -sSL $url | tar -xz --directory=bin
echo "$(pwd)/bin" >> $GITHUB_PATH
- name: Run tests
run: cd candle-book && cargo build && mdbook test -L ../target/debug/deps/


@ -3,6 +3,7 @@ members = [
"candle-core",
"candle-datasets",
"candle-examples",
"candle-book",
"candle-nn",
"candle-pyo3",
"candle-transformers",
@ -11,7 +12,6 @@ members = [
"tensor-tools",
]
exclude = [
"candle-book",
"candle-flash-attn",
"candle-kernels",
"candle-metal-kernels",
@ -20,7 +20,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.9.0-alpha.3"
version = "0.9.0-alpha.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.3" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.3" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.3" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.3" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.3" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.3" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.3" }
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.1" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.1" }
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.1" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.1" }
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.1" }
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.15.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.14.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
@ -70,9 +70,9 @@ tokenizers = { version = "0.21.0", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.3.1"
ug-cuda = "0.3.1"
ug-metal = "0.3.1"
ug = "0.2.0"
ug-cuda = "0.2.0"
ug-metal = "0.2.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}


@ -42,7 +42,7 @@ clap = { workspace = true }
criterion = { workspace = true }
[features]
default = []
default = ["cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]


@ -6,18 +6,99 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
// xs: [1024, 64, 1924], c Tensor[dims 128, 64, 8; f32, cuda:0] Conv1dConfig { padding: 0, stride: 4, dilation: 1, groups: 1 }
fn main() -> Result<()> {
let device = Device::new_cuda(0)?;
let x = Tensor::randn(0f32, 1.0, (1024, 64, 1924), &device)?;
let c = Tensor::randn(0f32, 1.0, (128, 64, 8), &device)?;
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
drop(_x1);
for _ in 0..20 {
let start_time = std::time::Instant::now();
let _x1 = x.conv1d(&c, 0, 4, 1, 1)?;
const USE_CUDA_GRAPH: bool = true;
fn cuda_graph() -> Result<()> {
let device = Device::new_cuda_with_stream(0)?;
let cu_device = match &device {
Device::Cuda(dev) => dev,
_ => unreachable!(),
};
let cu_stream = cu_device.cuda_stream();
{
// load_ptx cannot be called while capturing the stream so we need this to happen
// beforehand.
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let _x = x.mul(&u)?.broadcast_add(&v)?;
let _x = x.affine(1., 0.5)?;
x.slice_set(&u, 0, 0)?;
device.synchronize()?;
}
if USE_CUDA_GRAPH {
cu_stream.begin_capture(
cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
)?;
}
{
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
for _i in 0..100 {
// x.slice_set(&u, 0, 0)?;
// x.broadcast_add(&v)?;
x = x.affine(1., 0.5)?;
// x = (&u + &x)?;
}
}
if USE_CUDA_GRAPH {
println!("capturing graph");
let cu_graph = match cu_stream.end_capture(
cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
)? {
None => anyhow::bail!("no graph captured"),
Some(cu_graph) => cu_graph,
};
println!("graph captured!");
for i in 1..100 {
println!("graph exec {i}");
cu_graph.launch()?;
println!("sync");
if let Err(err) = device.synchronize() {
println!("err: {err:?}")
}
println!("done syncing");
}
} else {
device.synchronize()?;
println!("conv1d: {:?}", start_time.elapsed());
}
Ok(())
}
fn main() -> Result<()> {
cuda_graph()?;
return Ok(());
}
fn _matmul() -> Result<()> {
let device = Device::new_cuda_with_stream(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;
candle_core::cuda::set_gemm_reduced_precision_f32(false);
candle_core::cuda::set_gemm_reduced_precision_bf16(false);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("fp32: {:?}", start_time.elapsed());
drop(_x1);
candle_core::cuda::set_gemm_reduced_precision_f32(true);
candle_core::cuda::set_gemm_reduced_precision_bf16(true);
let _x1 = x.matmul(&x)?;
drop(_x1);
let start_time = std::time::Instant::now();
let _x1 = x.matmul(&x)?;
device.synchronize()?;
println!("tf32: {:?}", start_time.elapsed());
drop(_x1);
Ok(())
}
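
Note: the cuda_graph function above follows the standard CUDA stream-capture pattern: run the kernels once outside capture so PTX modules are loaded eagerly, record the steady-state sequence into a graph, then replay the graph to avoid per-launch CPU dispatch overhead. A distilled sketch, assuming only the cudarc and candle APIs already used in this file (shapes and loop counts are placeholders, not the benchmark's):

fn capture_and_replay() -> anyhow::Result<()> {
    use candle_core::{DType, Device, Tensor};
    let device = Device::new_cuda_with_stream(0)?;
    let cu_stream = match &device {
        Device::Cuda(dev) => dev.cuda_stream(),
        _ => unreachable!(),
    };
    // Warm up outside the capture: load_ptx cannot run while capturing.
    let x = Tensor::zeros((1024, 1024), DType::F32, &device)?;
    let _ = x.affine(1., 0.5)?;
    device.synchronize()?;
    // Record the work into a graph instead of executing it eagerly.
    cu_stream.begin_capture(
        cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
    )?;
    let mut y = Tensor::zeros((1024, 1024), DType::F32, &device)?;
    for _ in 0..10 {
        y = y.affine(1., 0.5)?;
    }
    let graph = match cu_stream.end_capture(
        cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY,
    )? {
        None => anyhow::bail!("no graph captured"),
        Some(graph) => graph,
    };
    // Replay the whole recorded kernel sequence with a single launch call.
    for _ in 0..100 {
        graph.launch()?;
    }
    device.synchronize()?;
    Ok(())
}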


@ -14,7 +14,6 @@ pub struct ParamsConv1D {
pub(crate) padding: usize,
pub(crate) stride: usize,
pub(crate) dilation: usize,
pub(crate) cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl ParamsConv1D {
@ -55,7 +54,7 @@ impl ParamsConvTranspose1D {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudnnFwdAlgo {
ImplicitGemm,
ImplicitPrecompGemm,
@ -152,19 +151,6 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
self.conv1d_with_algo(kernel, padding, stride, dilation, groups, None)
}
/// Applies a 1D convolution over the input tensor.
pub fn conv1d_with_algo(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result<Self> {
let (c_out, c_in_k, k_size) = kernel.dims3()?;
let (b_size, c_in, l_in) = self.dims3()?;
@ -188,7 +174,6 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo,
};
if groups == 1 {
self.conv1d_single_group(kernel, &params)
@ -293,18 +278,6 @@ impl Tensor {
stride: usize,
dilation: usize,
groups: usize,
) -> Result<Self> {
self.conv2d_with_algo(kernel, padding, stride, dilation, groups, None)
}
pub fn conv2d_with_algo(
&self,
kernel: &Self,
padding: usize,
stride: usize,
dilation: usize,
groups: usize,
cudnn_fwd_algo: Option<CudnnFwdAlgo>,
) -> Result<Self> {
let (b_size, c_in, i_h, i_w) = self.dims4()?;
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
@ -324,7 +297,7 @@ impl Tensor {
padding,
stride,
dilation,
cudnn_fwd_algo,
cudnn_fwd_algo: None,
};
if groups == 1 {
self.conv2d_single_group(kernel, &params)
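
Note: one side of this file routes the convolutions through conv1d_with_algo/conv2d_with_algo, taking an explicit Option<CudnnFwdAlgo>, while the other always lets the backend pick the algorithm. The plain entry points are identical on both sides; a usage sketch (shapes are illustrative only):

fn conv_calls(dev: &candle_core::Device) -> candle_core::Result<()> {
    use candle_core::Tensor;
    // conv1d: input (batch, c_in, l), kernel (c_out, c_in / groups, k).
    let x = Tensor::randn(0f32, 1.0, (1, 4, 16), dev)?;
    let k = Tensor::randn(0f32, 1.0, (8, 4, 3), dev)?;
    let _y = x.conv1d(&k, /*padding*/ 0, /*stride*/ 1, /*dilation*/ 1, /*groups*/ 1)?;
    // conv2d: input (batch, c_in, h, w), kernel (c_out, c_in / groups, kh, kw).
    let x = Tensor::randn(0f32, 1.0, (1, 4, 16, 16), dev)?;
    let k = Tensor::randn(0f32, 1.0, (8, 4, 3, 3), dev)?;
    let _y = x.conv2d(&k, 0, 1, 1, 1)?;
    Ok(())
}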


@ -1289,15 +1289,6 @@ impl Map2 for MatMul {
} else {
Parallelism::None
};
let (b, m, n, k) = if b_skip == 0 && a_skip == m * k {
// a_skip and c_skip should be updated but step is always 0 so
// it wouldn't matter.
(1, b * m, n, k)
} else if a_skip == 0 && b_skip == n * k {
(1, m, b * n, k)
} else {
(b, m, n, k)
};
for step in 0..b {
let lhs_p = &lhs[step * a_skip..];
let rhs_p = &rhs[step * b_skip..];
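
Note: the block shown in this hunk collapses a batched matmul into one larger gemm when the strides allow it: with b_skip == 0 and a_skip == m * k, the b products of an (m, k) slice against a shared (k, n) rhs equal a single (b * m, k) x (k, n) product (and symmetrically for a shared lhs). A minimal check of that equivalence on plain vectors (naive reference code, not the gemm crate's API):

fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    // Naive row-major matmul: a is (m, k), b is (k, n), result is (m, n).
    let mut c = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    c
}

fn main() {
    let (b, m, k, n) = (3, 2, 4, 5);
    let lhs: Vec<f32> = (0..b * m * k).map(|v| v as f32).collect();
    let rhs: Vec<f32> = (0..k * n).map(|v| v as f32).collect();
    // Batched: b separate (m, k) x (k, n) products against the shared rhs.
    let batched: Vec<f32> = (0..b)
        .flat_map(|s| matmul(&lhs[s * m * k..(s + 1) * m * k], &rhs, m, k, n))
        .collect();
    // Collapsed: a single (b * m, k) x (k, n) product over the same data.
    let collapsed = matmul(&lhs, &rhs, b * m, k, n);
    assert_eq!(batched, collapsed);
}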


@ -122,104 +122,3 @@ pub(crate) fn launch_conv2d<
}
Ok(())
}
pub(crate) fn launch_conv1d<
T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType,
Y: cudarc::cudnn::CudnnDataType,
>(
src: &CudaView<T>,
src_l: &crate::Layout,
filter: &CudaView<T>,
dst: &mut CudaSlice<T>,
params: &crate::conv::ParamsConv1D,
dev: &crate::cuda_backend::CudaDevice,
) -> crate::Result<()> {
use crate::conv::CudnnFwdAlgo as CandleAlgo;
use cudarc::cudnn::sys::cudnnConvolutionFwdAlgo_t as A;
let device_id = dev.id();
let cudnn = CUDNN.with(|cudnn| {
if let Some(cudnn) = cudnn.borrow().get(&device_id) {
return Ok(cudnn.clone());
}
let c = Cudnn::new(dev.cuda_stream());
if let Ok(c) = &c {
cudnn.borrow_mut().insert(device_id, c.clone());
}
c
})?;
let conv = cudnn.create_conv2d::<Y>(
/* pad */ [params.padding as i32, 0],
/* stride */ [params.stride as i32, 1],
/* dilation */ [params.dilation as i32, 1],
cudarc::cudnn::sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION,
)?;
// https://docs.nvidia.com/deeplearning/cudnn/backend/latest/api/cudnn-ops-library.html#cudnnsettensornddescriptor
// > Tensors are restricted to having at least 4 dimensions, and at most CUDNN_DIM_MAX
// > dimensions (defined in cudnn.h). When working with lower dimensional data, it is
// > recommended that the user create a 4D tensor, and set the size along unused dimensions
// > to 1.
let x_shape = [
params.b_size as i32,
params.c_in as i32,
params.l_in as i32,
1,
];
// Note that `src` already starts at the proper offset.
let x = if src_l.is_contiguous() {
cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
x_shape,
)?
} else {
let s = src_l.stride();
cudnn.create_4d_tensor_ex::<T>(x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, 1i32])?
};
let w = cudnn.create_4d_filter::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[
params.c_out as i32,
params.c_in as i32,
params.k_size as i32,
1,
],
)?;
let l_out = params.l_out() as i32;
let y = cudnn.create_4d_tensor::<T>(
cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
[params.b_size as i32, params.c_out as i32, l_out, 1],
)?;
let conv1d = ConvForward {
conv: &conv,
x: &x,
w: &w,
y: &y,
};
let alg = match params.cudnn_fwd_algo {
None => conv1d.pick_algorithm()?,
Some(CandleAlgo::ImplicitGemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
Some(CandleAlgo::ImplicitPrecompGemm) => {
A::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
Some(CandleAlgo::Gemm) => A::CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
Some(CandleAlgo::Direct) => A::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
Some(CandleAlgo::Fft) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT,
Some(CandleAlgo::FftTiling) => A::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
Some(CandleAlgo::Winograd) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
Some(CandleAlgo::WinogradNonFused) => A::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
Some(CandleAlgo::Count) => A::CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
};
let workspace_size = conv1d.get_workspace_size(alg)?;
let mut workspace = dev.cuda_stream().alloc_zeros::<u8>(workspace_size)?;
unsafe {
conv1d.launch::<CudaSlice<u8>, _, _, _>(
alg,
Some(&mut workspace),
(T::one(), T::zero()),
src,
filter,
dst,
)?;
}
Ok(())
}
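
Note: launch_conv1d above drives cuDNN's 2D convolution descriptors with a trailing unit dimension, per the quoted documentation. A small sketch of the shape bookkeeping (pure arithmetic, hypothetical helper, no cuDNN calls):

// Output length of a 1D convolution, matching ParamsConv1D::l_out():
fn l_out(l_in: usize, k: usize, padding: usize, stride: usize, dilation: usize) -> usize {
    (l_in + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    let (b, c_in, c_out, l_in, k) = (1024, 64, 128, 1924, 8);
    let (padding, stride, dilation) = (0, 4, 1);
    // The descriptors set up by launch_conv1d, each padded with a final 1:
    let x_shape = [b, c_in, l_in, 1]; // input tensor
    let w_shape = [c_out, c_in, k, 1]; // filter
    let y_shape = [b, c_out, l_out(l_in, k, padding, stride, dilation), 1];
    println!("x: {x_shape:?}, w: {w_shape:?}, y: {y_shape:?}");
}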


@ -46,61 +46,11 @@ impl std::fmt::Debug for CudaDevice {
}
}
impl CudaDevice {
#[allow(clippy::missing_safety_doc)]
pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc::<T>(len).w()
}
impl std::ops::Deref for CudaDevice {
type Target = Arc<cudarc::driver::CudaStream>;
pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
&self,
len: usize,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.alloc_zeros::<T>(len).w()
}
pub fn memcpy_htod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_htod(src, dst).w()
}
pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
&self,
src: &Src,
) -> Result<Vec<T>> {
self.stream.memcpy_dtov(src).w()
}
pub fn memcpy_dtod<
T,
Src: cudarc::driver::DevicePtr<T>,
Dst: cudarc::driver::DevicePtrMut<T>,
>(
&self,
src: &Src,
dst: &mut Dst,
) -> Result<()> {
self.stream.memcpy_dtod(src, dst).w()
}
pub fn memcpy_stod<
T: cudarc::driver::DeviceRepr,
Src: cudarc::driver::HostSlice<T> + ?Sized,
>(
&self,
src: &Src,
) -> Result<cudarc::driver::CudaSlice<T>> {
self.stream.memcpy_stod(src).w()
fn deref(&self) -> &Self::Target {
&self.stream
}
}
@ -176,7 +126,7 @@ impl CudaDevice {
let slice = match dtype {
DType::U8 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u8>(elem_count)? };
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u8;
@ -188,7 +138,7 @@ impl CudaDevice {
}
DType::U32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<u32>(elem_count)? };
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as u32;
@ -200,7 +150,7 @@ impl CudaDevice {
}
DType::I64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<i64>(elem_count)? };
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as i64;
@ -212,7 +162,7 @@ impl CudaDevice {
}
DType::BF16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<bf16>(elem_count)? };
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = bf16::from_f64(v);
@ -224,7 +174,7 @@ impl CudaDevice {
}
DType::F16 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f16>(elem_count)? };
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = f16::from_f64(v);
@ -236,7 +186,7 @@ impl CudaDevice {
}
DType::F32 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f32>(elem_count)? };
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
let v = v as f32;
@ -248,7 +198,7 @@ impl CudaDevice {
}
DType::F64 => {
// SAFETY: Set later by running the fill kernel.
let data = unsafe { self.alloc::<f64>(elem_count) }?;
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
let mut builder = self.stream.launch_builder(&func);
builder.arg(&data);
@ -375,31 +325,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc_zeros::<u8>(elem_count)?;
let data = self.alloc_zeros::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc_zeros::<u32>(elem_count)?;
let data = self.alloc_zeros::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc_zeros::<i64>(elem_count)?;
let data = self.alloc_zeros::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc_zeros::<bf16>(elem_count)?;
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc_zeros::<f16>(elem_count)?;
let data = self.alloc_zeros::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc_zeros::<f32>(elem_count)?;
let data = self.alloc_zeros::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc_zeros::<f64>(elem_count)?;
let data = self.alloc_zeros::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
@ -423,12 +373,12 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count)? };
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count)? };
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
curand.0.fill_with_uniform(&mut data).w()?;
CudaStorageSlice::F64(data)
}
@ -467,7 +417,7 @@ impl BackendDevice for CudaDevice {
.w()?
}
DType::F32 => {
let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
curand
.0
.fill_with_normal(&mut data, mean as f32, std as f32)
@ -475,7 +425,7 @@ impl BackendDevice for CudaDevice {
CudaStorageSlice::F32(data)
}
DType::F64 => {
let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
curand.0.fill_with_normal(&mut data, mean, std).w()?;
CudaStorageSlice::F64(data)
}
@ -494,31 +444,31 @@ impl BackendDevice for CudaDevice {
let elem_count = shape.elem_count();
let slice = match dtype {
DType::U8 => {
let data = self.alloc::<u8>(elem_count)?;
let data = self.alloc::<u8>(elem_count).w()?;
CudaStorageSlice::U8(data)
}
DType::U32 => {
let data = self.alloc::<u32>(elem_count)?;
let data = self.alloc::<u32>(elem_count).w()?;
CudaStorageSlice::U32(data)
}
DType::I64 => {
let data = self.alloc::<i64>(elem_count)?;
let data = self.alloc::<i64>(elem_count).w()?;
CudaStorageSlice::I64(data)
}
DType::BF16 => {
let data = self.alloc::<bf16>(elem_count)?;
let data = self.alloc::<bf16>(elem_count).w()?;
CudaStorageSlice::BF16(data)
}
DType::F16 => {
let data = self.alloc::<f16>(elem_count)?;
let data = self.alloc::<f16>(elem_count).w()?;
CudaStorageSlice::F16(data)
}
DType::F32 => {
let data = self.alloc::<f32>(elem_count)?;
let data = self.alloc::<f32>(elem_count).w()?;
CudaStorageSlice::F32(data)
}
DType::F64 => {
let data = self.alloc::<f64>(elem_count)?;
let data = self.alloc::<f64>(elem_count).w()?;
CudaStorageSlice::F64(data)
}
};
@ -531,31 +481,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
let slice = match T::cpu_storage_ref(s) {
CpuStorageRef::U8(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorageRef::U32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorageRef::I64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorageRef::BF16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorageRef::F16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorageRef::F32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorageRef::F64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -568,31 +518,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(storage)?;
let data = self.memcpy_stod(storage).w()?;
CudaStorageSlice::F64(data)
}
};
@ -605,31 +555,31 @@ impl BackendDevice for CudaDevice {
fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
let slice = match storage {
CpuStorage::U8(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::U8(data)
}
CpuStorage::U32(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::U32(data)
}
CpuStorage::I64(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::I64(data)
}
CpuStorage::BF16(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::BF16(data)
}
CpuStorage::F16(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F16(data)
}
CpuStorage::F32(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F32(data)
}
CpuStorage::F64(storage) => {
let data = self.memcpy_stod(&storage)?;
let data = self.memcpy_stod(&storage).w()?;
CudaStorageSlice::F64(data)
}
};
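
Note: one side of this file wraps each stream operation in a CudaDevice method that already returns candle's Result; the other replaces the wrappers with a Deref to the underlying Arc<cudarc::driver::CudaStream>, so call sites use the stream API directly and convert driver errors themselves (spelled .w()? inside the crate). A sketch of the resulting call-site shape, assuming candle_core re-exports CudaDevice and that Error::wrap stays public:

fn fill_example(dev: &candle_core::CudaDevice) -> candle_core::Result<()> {
    use candle_core::Error;
    let host = vec![1f32; 1024];
    // alloc_zeros/memcpy_* resolve through Deref to the CudaStream and
    // return cudarc's Result, so each call converts the error explicitly.
    let mut slice = dev.alloc_zeros::<f32>(1024).map_err(Error::wrap)?;
    dev.memcpy_htod(&host, &mut slice).map_err(Error::wrap)?;
    let back = dev.memcpy_dtov(&slice).map_err(Error::wrap)?;
    assert_eq!(back.len(), 1024);
    Ok(())
}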


@ -39,7 +39,7 @@ impl SlicePtrOrNull<usize> {
let ds = if l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat())?)
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat()).w()?)
};
Ok(ds)
}
@ -89,7 +89,7 @@ impl Map1 for Affine {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("affine"), &kernels::AFFINE)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el)? };
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -120,7 +120,7 @@ impl Map1 for Elu {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el)? };
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -134,7 +134,6 @@ impl Map1 for Elu {
}
}
#[allow(unused)]
struct Im2Col1D {
l_k: usize,
stride: usize,
@ -143,7 +142,6 @@ struct Im2Col1D {
}
impl Im2Col1D {
#[allow(unused)]
fn l_out(&self, l: usize) -> usize {
(l + 2 * self.padding - self.dilation * (self.l_k - 1) - 1) / self.stride + 1
}
@ -159,15 +157,15 @@ impl Map1 for Im2Col1D {
let shape = layout.shape();
let dims = shape.dims();
let l_out = self.l_out(dims[2]);
let threads = dims[0] * l_out * dims[1];
let cfg = LaunchConfig::for_num_elems(threads as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let dst_el = dims[0] * l_out * dims[1] * self.l_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(threads * self.l_k)? };
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let mut builder = func.builder();
barg!(builder, threads);
barg!(builder, dst_el);
barg!(builder, l_out);
barg!(builder, self.l_k);
barg!(builder, self.stride);
@ -212,11 +210,11 @@ impl Map1 for Im2Col {
let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(dst_el)? };
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, h_out);
@ -251,7 +249,7 @@ impl Map1 for Powf {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el)? };
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -304,7 +302,9 @@ impl Map1Any for FastReduce<'_> {
block_dim: (block_dim as u32, 1, 1),
shared_mem_bytes: 0,
};
let ds = dev.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())?;
let ds = dev
.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())
.w()?;
let src = &src.slice(layout.start_offset()..);
let (name, check_empty, return_index) = match self.1 {
ReduceOp::Sum => ("fast_sum", false, false),
@ -319,7 +319,7 @@ impl Map1Any for FastReduce<'_> {
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
if return_index {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<u32>(dst_el)? };
let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
let mut builder = func.builder();
barg!(builder, src_el);
barg!(builder, el_to_sum_per_block);
@ -332,7 +332,7 @@ impl Map1Any for FastReduce<'_> {
Ok(S::U32(out))
} else {
// SAFETY: filled in by the follow up kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let mut builder = func.builder();
barg!(builder, src_el);
barg!(builder, el_to_sum_per_block);
@ -362,7 +362,7 @@ impl<U: UnaryOpT> Map1 for U {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let mut out = unsafe { dev.alloc::<T>(el_count)? };
let mut out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let mut builder = func.builder();
barg!(builder, el_count);
barg!(builder, dims.len());
@ -403,7 +403,7 @@ impl Map1 for IndexSelect<'_> {
};
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat())?;
let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat()).w()?;
let src = match src_l.contiguous_offsets() {
Some((o1, o2)) => src.slice(o1..o2),
None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
@ -416,7 +416,7 @@ impl Map1 for IndexSelect<'_> {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let mut builder = func.builder();
barg!(builder, dst_el);
barg!(builder, ids_dims.len());
@ -471,7 +471,7 @@ impl Map1 for Gather<'_> {
let ids_dim_sz = ids_l.dims()[dim];
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el)? };
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, ids);
@ -608,7 +608,7 @@ impl Map2 for Conv1D<'_> {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else if dims.len() == 2 {
@ -616,7 +616,7 @@ impl Map2 for Conv1D<'_> {
} else {
crate::bail!("unexpected input shape for conv1d {dims:?}")
};
let ds = dev.memcpy_stod(&ds)?;
let ds = dev.memcpy_stod(&ds).w()?;
let mut builder = func.builder();
barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
builder.arg(&ds);
@ -651,7 +651,7 @@ impl Map2 for Conv2D<'_> {
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
let ds = if dims.len() == 4 {
@ -659,7 +659,7 @@ impl Map2 for Conv2D<'_> {
} else {
crate::bail!("unexpected input shape for conv2d {dims:?}")
};
let ds = dev.memcpy_stod(&ds)?;
let ds = dev.memcpy_stod(&ds).w()?;
let mut builder = func.builder();
barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
builder.arg(&ds);
@ -687,7 +687,7 @@ impl Map1 for Col2Im1D {
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let mut im = unsafe { dev.alloc::<T>(dst_el)? };
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
@ -722,7 +722,7 @@ impl Map2 for ConvTranspose1D<'_> {
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
let ds = if dims.len() == 3 {
@ -730,7 +730,7 @@ impl Map2 for ConvTranspose1D<'_> {
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.memcpy_stod(&ds)?;
let ds = dev.memcpy_stod(&ds).w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, l_out);
@ -770,7 +770,7 @@ impl Map2 for ConvTranspose2D<'_> {
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
let ds = if dims.len() == 4 {
@ -778,7 +778,7 @@ impl Map2 for ConvTranspose2D<'_> {
} else {
crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
};
let ds = dev.memcpy_stod(&ds)?;
let ds = dev.memcpy_stod(&ds).w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, out_w);
@ -837,8 +837,8 @@ impl Map1 for Pool2D {
};
let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let ds = dev.memcpy_stod(&ds)?;
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = dev.memcpy_stod(&ds).w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, self.w_k);
@ -876,8 +876,8 @@ impl Map1 for UpsampleNearest2D {
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el)? };
let ds = dev.memcpy_stod(&ds)?;
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let ds = dev.memcpy_stod(&ds).w()?;
let scale_w = dims[2] as f64 / out_w as f64;
let scale_h = dims[3] as f64 / out_h as f64;
let mut builder = func.builder();
@ -930,12 +930,13 @@ impl Map2 for WhereCond<'_> {
let el = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el as u32);
let ds = dev
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
.memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
.w()?;
let t = &t.slice(layout_t.start_offset()..);
let f = &f.slice(layout_f.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el)? };
let out = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -966,13 +967,16 @@ impl<U: crate::op::BinaryOpT> Map2 for U {
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
SlicePtrOrNull::Ptr(
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(elem_count)? };
let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let mut builder = func.builder();
barg!(builder, elem_count);
barg!(builder, dims.len());
@ -1003,7 +1007,10 @@ impl Map2Any for Cmp {
let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
SlicePtrOrNull::Null
} else {
SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
SlicePtrOrNull::Ptr(
dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
.w()?,
)
};
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
@ -1017,7 +1024,7 @@ impl Map2Any for Cmp {
};
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<u8>(elem_count)? };
let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
let mut builder = func.builder();
barg!(builder, elem_count);
barg!(builder, dims.len());
@ -1201,6 +1208,7 @@ fn gemm_config<T>(
mnk: (m, n, k),
})?,
};
Ok(StridedBatchedConfig {
batch_size: b as i32,
gemm,
@ -1261,7 +1269,7 @@ impl BackendStorage for CudaStorage {
let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
let slice = match dtype {
DType::U8 => {
let out = unsafe { dev.alloc::<u8>(el)? };
let out = unsafe { dev.alloc::<u8>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1272,7 +1280,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::U8(out)
}
DType::U32 => {
let out = unsafe { dev.alloc::<u32>(el)? };
let out = unsafe { dev.alloc::<u32>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1283,7 +1291,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::U32(out)
}
DType::I64 => {
let out = unsafe { dev.alloc::<i64>(el)? };
let out = unsafe { dev.alloc::<i64>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1294,7 +1302,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::I64(out)
}
DType::BF16 => {
let out = unsafe { dev.alloc::<bf16>(el)? };
let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1305,7 +1313,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::BF16(out)
}
DType::F16 => {
let out = unsafe { dev.alloc::<f16>(el)? };
let out = unsafe { dev.alloc::<f16>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1316,7 +1324,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::F16(out)
}
DType::F32 => {
let out = unsafe { dev.alloc::<f32>(el)? };
let out = unsafe { dev.alloc::<f32>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1327,7 +1335,7 @@ impl BackendStorage for CudaStorage {
CudaStorageSlice::F32(out)
}
DType::F64 => {
let out = unsafe { dev.alloc::<f64>(el)? };
let out = unsafe { dev.alloc::<f64>(el) }.w()?;
let mut builder = func.builder();
barg!(builder, el);
barg!(builder, dims.len());
@ -1437,7 +1445,6 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}
#[cfg(not(feature = "cudnn"))]
fn conv1d(
&self,
l: &Layout,
@ -1466,11 +1473,12 @@ impl BackendStorage for CudaStorage {
let n = params.c_out;
let k = params.k_size * params.c_in;
let m = l_out;
let col_l = Layout::contiguous((b * m, k));
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
@ -1478,9 +1486,10 @@ impl BackendStorage for CudaStorage {
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, l_out, n)).transpose(1, 2)?;
let mut res_t = unsafe { self.device().alloc_uninit(res_l.shape(), res.dtype())? };
@ -1488,72 +1497,6 @@ impl BackendStorage for CudaStorage {
Ok(res_t)
}
#[cfg(feature = "cudnn")]
fn conv1d(
&self,
inp_l: &Layout,
kernel: &Self,
kernel_l: &Layout,
params: &crate::conv::ParamsConv1D,
) -> Result<Self> {
let device = self.device().clone();
if !kernel_l.is_contiguous() {
let slice = Conv1D(params).map(&self.slice, inp_l, &kernel.slice, kernel_l, &device)?;
return Ok(Self { slice, device });
}
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
let slice = match (&self.slice, &kernel.slice) {
(S::U8(inp), S::U8(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
crate::cudnn::launch_conv1d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
}
(S::BF16(inp), S::BF16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
crate::cudnn::launch_conv1d::<bf16, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::BF16(out)
}
(S::F16(inp), S::F16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
crate::cudnn::launch_conv1d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
}
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
crate::cudnn::launch_conv1d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
}
(S::F64(inp), S::F64(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
crate::cudnn::launch_conv1d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
}
(S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv1d does not support u32"))?,
(S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv1d does not support i64"))?,
_ => Err(CudaError::InternalError("dtype mismatch in conv1d"))?,
};
Ok(Self { slice, device })
}
fn conv_transpose1d(
&self,
l: &Layout,
@ -1644,11 +1587,12 @@ impl BackendStorage for CudaStorage {
let n = params.c_out;
let k = params.k_h * params.k_w * params.c_in;
let m = h_out * w_out;
let col_l = Layout::contiguous((b * m, k));
let col_l = Layout::contiguous((b, m, k));
let res = if kernel_l.is_contiguous() {
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
} else {
// Make the kernel contiguous if not already the case.
let mut kernel_c = unsafe {
@ -1656,9 +1600,10 @@ impl BackendStorage for CudaStorage {
.alloc_uninit(kernel_l.shape(), kernel.dtype())?
};
kernel.copy_strided_src(&mut kernel_c, 0, kernel_l)?;
let kernel_l =
Layout::contiguous_with_offset((n, k), kernel_l.start_offset()).transpose(0, 1)?;
col.matmul(kernel, (1, b * m, n, k), &col_l, &kernel_l)?
let kernel_l = Layout::contiguous_with_offset((1, n, k), kernel_l.start_offset())
.transpose(1, 2)?
.broadcast_as((b, k, n))?;
col.matmul(kernel, (b, m, n, k), &col_l, &kernel_l)?
};
let res_l = Layout::contiguous((b, h_out, w_out, n))
.transpose(1, 2)?
@ -1687,7 +1632,7 @@ impl BackendStorage for CudaStorage {
(S::U8(inp), S::U8(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<u8>(dst_el)? };
let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::U8(out)
@ -1695,7 +1640,7 @@ impl BackendStorage for CudaStorage {
(S::BF16(inp), S::BF16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
// Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
// version.
// https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
@ -1706,7 +1651,7 @@ impl BackendStorage for CudaStorage {
(S::F16(inp), S::F16(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f16>(dst_el)? };
let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F16(out)
@ -1714,7 +1659,7 @@ impl BackendStorage for CudaStorage {
(S::F32(inp), S::F32(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f32>(dst_el)? };
let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F32(out)
@ -1722,7 +1667,7 @@ impl BackendStorage for CudaStorage {
(S::F64(inp), S::F64(k)) => {
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(kernel_l.start_offset()..);
let mut out = unsafe { device.alloc::<f64>(dst_el)? };
let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
.map_err(crate::Error::wrap)?;
S::F64(out)
@ -1838,7 +1783,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<bf16>(elem_count)? };
let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::BF16(out)
@ -1847,7 +1792,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f16>(elem_count)? };
let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F16(out)
@ -1856,7 +1801,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f32>(elem_count)? };
let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
.w()?;
CudaStorageSlice::F32(out)
@ -1865,7 +1810,7 @@ impl BackendStorage for CudaStorage {
let lhs = &lhs.slice(lhs_l.start_offset()..);
let rhs = &rhs.slice(rhs_l.start_offset()..);
let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
let mut out = unsafe { dev.alloc::<f64>(elem_count)? };
let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
unsafe {
self.device
.blas
@ -1938,7 +1883,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1954,7 +1899,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1970,7 +1915,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
let mut builder = func.builder();
@ -1986,7 +1931,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
let mut builder = func.builder();
@ -2002,7 +1947,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
let mut builder = func.builder();
@ -2018,7 +1963,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
let mut builder = func.builder();
@ -2034,7 +1979,7 @@ impl BackendStorage for CudaStorage {
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
if src_l.is_contiguous() {
dev.memcpy_dtod(&src, &mut dst)?
dev.memcpy_dtod(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
let mut builder = func.builder();
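
Note: the conv1d/conv2d im2col paths in this file express the same product in two layouts: a collapsed (1, b * m, n, k) gemm with the kernel reshaped to (k, n), and a batched (b, m, n, k) call whose kernel operand is broadcast across the batch with broadcast_as (stride 0 on the batch axis, so no copy is made). Shape bookkeeping only, no candle calls:

fn main() {
    // conv2d-as-matmul sizes: im2col produces (b, m, k) with
    // m = h_out * w_out and k = k_h * k_w * c_in; the kernel is (n, k).
    let (b, h_out, w_out, k_h, k_w, c_in, n) = (8, 30, 30, 3, 3, 64, 128);
    let m = h_out * w_out;
    let k = k_h * k_w * c_in;
    // Collapsed form: fold the batch into the rows of a single gemm.
    let collapsed_lhs = (1, b * m, k);
    // Batched form: keep the batch axis and broadcast the kernel to (b, k, n).
    let batched_lhs = (b, m, k);
    let batched_rhs = (b, k, n);
    println!("collapsed {collapsed_lhs:?}; batched {batched_lhs:?} x {batched_rhs:?}");
}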


@ -99,7 +99,7 @@ fn dequantize_f32(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(elem_count)? };
let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -159,7 +159,7 @@ fn dequantize_f16(
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
// See e.g.
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
let cfg = cudarc::driver::LaunchConfig {
@ -216,7 +216,7 @@ fn dequantize_mul_mat_vec(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows)? };
let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (block_num_y as u32, 1, 1),
@ -256,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
let y_size_in_bytes =
b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
let kernel_name = match dtype {
@ -274,7 +274,7 @@ fn mul_mat_vec_via_q8_1(
};
let kernel_name = format!("{kernel_name}{b_size}");
let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
let (nblocks, nwarps) = match b_size {
1 => (nrows as u32, 4),
@ -329,7 +329,7 @@ fn mul_mat_via_q8_1(
let k_padded = pad(k, MATRIX_ROW_PADDING);
let y_size_in_bytes =
k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
let (kernel_name, mmq_x, mmq_y) = match dtype {
@ -346,7 +346,7 @@ fn mul_mat_via_q8_1(
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
};
let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (
ceil_div(x_rows, mmq_y) as u32,
@ -378,7 +378,7 @@ impl QCudaStorage {
let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
let padded_size_in_bytes =
ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
Ok(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -425,7 +425,8 @@ impl QCudaStorage {
let buffer = self
.device
.memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
.memcpy_dtov(&self.data.inner.slice(..self.data.len))
.w()?;
let mut out = vec![0.0; elem_count];
let block_len = elem_count / self.dtype.block_size();
match self.dtype {
@ -456,7 +457,9 @@ impl QCudaStorage {
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
// Run the quantization on cpu.
let src = match &src.slice {
crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
crate::cuda_backend::CudaStorageSlice::F32(data) => {
self.device.memcpy_dtov(data).w()?
}
_ => crate::bail!("only f32 can be quantized"),
};
let src_len = src.len();
@ -466,9 +469,10 @@ impl QCudaStorage {
let data = qcpu_storage.data()?;
let padded_len =
data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
self.device
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
.memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
.w()?;
self.data = PaddedCudaSlice {
inner,
len: data.len(),
@ -602,8 +606,10 @@ pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
};
let dtype = T::DTYPE;
let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
device
.memcpy_htod(data, &mut inner.slice_mut(..data.len()))
.w()?;
Ok(QStorage::Cuda(QCudaStorage {
data: PaddedCudaSlice {
inner,
@ -625,9 +631,9 @@ mod test {
let el_padded = pad(el, MATRIX_ROW_PADDING);
let y_size_in_bytes =
el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.memcpy_stod(&vs).w()?;
quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
Ok(())
}
@@ -637,7 +643,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.memcpy_stod(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_vec_via_q8_1(
@@ -650,7 +656,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
// for n = 255, n.(n+1).(2n+1) / 6 = 5559680
// Q8 means 1/256 precision.
@@ -665,7 +671,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
assert_eq!(vs.len(), 1);
assert_eq!(vs[0], 5561851.0);
Ok(())
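
The reference value quoted in the comment is the closed form for a sum of squares, and the asserted kernel output differs from it by roughly 0.04%, in line with Q8's 1/256 precision. A quick check of the arithmetic:

```rust
fn main() {
    // Dot product of v = [0, 1, ..., 255] with itself.
    let direct: u64 = (0u64..=255).map(|v| v * v).sum();
    // Closed form n(n+1)(2n+1)/6 with n = 255.
    let closed_form = 255u64 * 256 * 511 / 6;
    assert_eq!(direct, closed_form); // both are 5_559_680
    // The test's 5_561_851 is off by 2_171, about 0.04% relative error.
}
```
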
@@ -676,7 +682,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let ncols = 256;
let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.memcpy_stod(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@@ -690,7 +696,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let vs = dev.memcpy_dtov(&vs.slice(..))?;
let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
/*
x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -717,7 +723,7 @@ mod test {
let dev = CudaDevice::new(0)?;
let (x_rows, ncols, y_cols) = (4, 16, 2048);
let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
let y = dev.memcpy_stod(&vs)?;
let y = dev.memcpy_stod(&vs).w()?;
let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
let cuda_storage = mul_mat_via_q8_1(
@@ -731,7 +737,7 @@ mod test {
&dev,
)?;
let vs = cuda_storage.as_cuda_slice::<f32>()?;
let _vs = dev.memcpy_dtov(&vs.slice(..))?;
let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
Ok(())
}
}


@@ -76,7 +76,7 @@ mod cuda {
Some((o1, o2)) => src.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<u32>(elem_count)? };
let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
let func = if self.asc {
dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
} else {


@@ -53,20 +53,6 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv1d(&w, /*padding*/ 1, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 5]);
// Same as pytorch default padding: use zeros.
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
@@ -177,22 +163,6 @@ fn conv2d(dev: &Device) -> Result<()> {
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = {
let t = Tensor::cat(&[&t.zeros_like()?, &t, &t.zeros_like()?], 0)?;
t.conv2d(&w, 0, 1, 1, 1)?
};
assert_eq!(res.dims(), [3, 2, 3, 3]);
assert_eq!(
test_utils::to_vec1_round(&res.i(0)?.flatten_all()?, 4)?,
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]
);
assert_eq!(
test_utils::to_vec1_round(&res.i(1)?.flatten_all()?, 4)?,
[
-4.2812, 2.0923, 5.2187, 7.5184, 0.752, -14.9426, 10.0087, 4.391, 0.2918, 1.6715,
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
]
);
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;


@@ -60,7 +60,7 @@ bindgen_cuda = { version = "0.1.1", optional = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
cudnn = ["candle/cudnn", "candle-nn/cudnn", "candle-transformers/cudnn"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
nccl = ["cuda", "cudarc/nccl", "dep:half"]


@@ -68,7 +68,7 @@ impl CustomOp1 for LayerNorm {
Some((o1, o2)) => slice.slice(o1..o2),
};
let elem_count = layout.shape().elem_count();
let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
let func =
dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
let cfg = LaunchConfig {


@@ -46,7 +46,7 @@ impl TextGeneration {
Sampling::ArgMax
} else {
match (top_k, top_p) {
(None, None) => Sampling::GumbelSoftmax { temperature },
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
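
After this change a plain temperature run draws from the full softmax distribution via multinomial sampling instead of the Gumbel-Softmax path. A minimal usage sketch of the resulting processor; the seed, temperature, and the tiny logits tensor are made-up values:

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};

fn main() -> Result<()> {
    // Neither top-k nor top-p given: temperature-only sampling.
    let sampling = Sampling::All { temperature: 0.8 };
    let mut processor = LogitsProcessor::from_sampling(42, sampling);
    let logits = Tensor::new(&[-1f32, 0., 0.2, 1.], &Device::Cpu)?;
    let token = processor.sample(&logits)?;
    println!("sampled token id: {token}");
    Ok(())
}
```
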


@@ -1,14 +0,0 @@
# Orpheus
Orpheus is a 3B text-to-speech model based on Llama.
- Weights on HuggingFace
[canopylabs/orpheus-3b-0.1-ft](https://huggingface.co/canopylabs/orpheus-3b-0.1-ft).
- Code on GitHub [canopyai/Orpheus-TTS](https://github.com/canopyai/Orpheus-TTS).
```bash
cargo run --example orpheus --features cuda -r
```


@@ -1,329 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llama::{Cache, Llama, LlamaConfig};
use candle_transformers::models::snac::{Config as SnacConfig, Model as SnacModel};
use tokenizers::Tokenizer;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/realtime_streaming_example/main.py#L43
const STOP_TOKEN_ID: u32 = 128258;
#[derive(Parser)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long, default_value = "Hey, how are you doing today?")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long, default_value_t = 0.6)]
temperature: f64,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
model_file: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// The output wav file.
#[arg(long, default_value = "out.wav")]
out_file: String,
#[arg(long, default_value = "3b-0.1-ft")]
which: Which,
#[arg(long, default_value = "tara")]
voice: Voice,
#[arg(long)]
use_flash_attn: bool,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Voice {
#[value(name = "tara")]
Tara,
#[value(name = "leah")]
Leah,
#[value(name = "jess")]
Jess,
#[value(name = "leo")]
Leo,
#[value(name = "dan")]
Dan,
#[value(name = "mia")]
Mia,
#[value(name = "zac")]
Zac,
#[value(name = "zoe")]
Zoe,
}
impl Voice {
fn as_str(&self) -> &'static str {
match self {
Voice::Tara => "tara",
Voice::Leah => "leah",
Voice::Jess => "jess",
Voice::Leo => "leo",
Voice::Dan => "dan",
Voice::Mia => "mia",
Voice::Zac => "zac",
Voice::Zoe => "zoe",
}
}
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3b-0.1-ft")]
ThreeB0_1Ft,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let prompt = args.prompt.clone();
let mut model = Model::load(args)?;
model.run(&prompt)?;
Ok(())
}
struct Model {
model: Llama,
tokenizer: Tokenizer,
logits_processor: candle_transformers::generation::LogitsProcessor,
cache: Cache,
device: Device,
verbose_prompt: bool,
snac: SnacModel,
out_file: String,
voice: Voice,
}
fn load_snac(device: &Device) -> Result<SnacModel> {
let api = hf_hub::api::sync::Api::new()?;
let m = api.model("hubertsiuzdak/snac_24khz".to_string());
let config = m.get("config.json")?;
let config: SnacConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let m = api.model("lmz/candle-snac".to_string());
let model = m.get("snac_24khz.safetensors")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, device)? };
let model = SnacModel::new(&config, vb)?;
Ok(model)
}
impl Model {
fn load(args: Args) -> Result<Self> {
let start = std::time::Instant::now();
let api = hf_hub::api::sync::Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.which {
Which::ThreeB0_1Ft => "canopylabs/orpheus-3b-0.1-ft".to_string(),
},
};
let revision = match args.revision {
Some(r) => r,
None => "main".to_string(),
};
let repo = api.repo(hf_hub::Repo::with_revision(
model_id,
hf_hub::RepoType::Model,
revision,
));
let model_files = match args.model_file {
Some(m) => vec![m.into()],
None => match args.which {
Which::ThreeB0_1Ft => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
let config = match args.config_file {
Some(m) => m.into(),
None => repo.get("config.json")?,
};
let tokenizer = match args.tokenizer_file {
Some(m) => m.into(),
None => repo.get("tokenizer.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let dtype = device.bf16_default_to_f32();
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, dtype, &device)? };
let config: LlamaConfig = serde_json::from_reader(std::fs::File::open(config)?)?;
let config = config.into_config(args.use_flash_attn);
let model = Llama::load(vb, &config)?;
let logits_processor = {
use candle_transformers::generation::{LogitsProcessor, Sampling};
let temperature = args.temperature;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k.as_ref(), args.top_p.as_ref()) {
(None, None) => Sampling::All { temperature },
(Some(&k), None) => Sampling::TopK { k, temperature },
(None, Some(&p)) => Sampling::TopP { p, temperature },
(Some(&k), Some(&p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
println!("loaded the model in {:?}", start.elapsed());
let cache = Cache::new(true, dtype, &config, &device)?;
let snac = load_snac(&device)?;
Ok(Self {
model,
tokenizer,
logits_processor,
cache,
device,
verbose_prompt: args.verbose_prompt,
snac,
voice: args.voice,
out_file: args.out_file,
})
}
fn run(&mut self, prompt: &str) -> Result<()> {
println!("running the model on '{}'", prompt);
let device = &self.device;
let prompt = format!("{voice}: {prompt}", voice = self.voice.as_str());
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/engine_class.py#L82
let mut tokens = [
&[128259],
tokens.get_ids(),
&[128009, 128260, 128261, 128257],
]
.concat();
if self.verbose_prompt {
println!("{:?}", tokens);
}
let mut cache = self.cache.clone();
println!("starting the inference loop");
let mut index_pos = 0;
let mut audio_tokens = vec![];
for index in 0..2000 {
let (context_size, context_index) = if index > 0 {
(1, index_pos)
} else {
(tokens.len(), 0)
};
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, context_index, &mut cache)?;
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = self.logits_processor.sample(&logits)?;
if let Some(tok) = self.tokenizer.id_to_token(next_token) {
match tok.strip_prefix("<custom_token_") {
Some(tok) => match tok.strip_suffix('>') {
Some(tok) => {
let tok = tok.parse::<u32>()?;
// https://github.com/canopyai/Orpheus-TTS/blob/df0b0d96685dd21885aef7f900ee7f705c669e94/orpheus_tts_pypi/orpheus_tts/decoder.py#L86C35-L86C63
let tok = tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096);
audio_tokens.push(tok);
}
None => {
println!("{index}: unexpected custom token {next_token} {tok}");
}
},
None => {
println!("{index}: unexpected token {next_token} {tok}");
}
}
}
if next_token == STOP_TOKEN_ID {
println!("reached stop token");
break;
}
tokens.push(next_token);
}
println!("generated {} audio tokens", audio_tokens.len());
let mut codes0 = vec![];
let mut codes1 = vec![];
let mut codes2 = vec![];
for audio_tokens in audio_tokens.chunks_exact(7) {
codes0.push(audio_tokens[0]);
for i in [1, 4] {
codes1.push(audio_tokens[i]);
}
for i in [2, 3, 5, 6] {
codes2.push(audio_tokens[i]);
}
}
let codes0 = Tensor::new(codes0, device)?.unsqueeze(0)?;
let codes1 = Tensor::new(codes1, device)?.unsqueeze(0)?;
let codes2 = Tensor::new(codes2, device)?.unsqueeze(0)?;
let pcm = self.snac.decode(&[&codes0, &codes1, &codes2])?;
println!("decoded to pcm {pcm:?}");
let mut output = std::fs::File::create(&self.out_file)?;
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24000)?;
Ok(())
}
}
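
The deleted example decodes speech in frames of seven `<custom_token_N>` ids: each position within a frame appears to use its own 4096-id range offset by 10, and the seven recovered codes fan out to the three SNAC codebooks (index 0 to `codes0`, indices 1 and 4 to `codes1`, the rest to `codes2`). A small illustration of the id arithmetic, with a made-up token id:

```rust
fn main() {
    // e.g. <custom_token_4207> generated at position 1 within its frame:
    let n: u32 = 4207;
    let position: u32 = 1;
    // Undo the global offset of 10 and the per-position 4096-id range,
    // mirroring `tok - 10 - ((audio_tokens.len() as u32 % 7) * 4096)`.
    let code = n - 10 - (position % 7) * 4096;
    assert_eq!(code, 101);
}
```
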


@@ -133,7 +133,6 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl)
padding,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv = if bias {
conv2d(p, filters, size, conv_cfg, vb.pp(format!("conv_{index}")))?


@@ -92,7 +92,6 @@ impl ConvBlock {
stride,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;


@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.3"
version = "0.9.0-alpha.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@@ -11,17 +11,14 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.3" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.1" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] }
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", features = ["cuda"] }
[features]
default = []
cudnn = ["candle/cudnn"]


@@ -2,6 +2,7 @@ mod ffi;
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::DevicePtr;
use candle::cuda_backend::WrapErr;
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
@@ -141,8 +142,10 @@ impl FlashAttn {
let seqlen_k_rounded = round_multiple(seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count)? };
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };
@@ -604,8 +607,8 @@ impl FlashAttnVarLen {
let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<f16>(elem_count)? };
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
let dst = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
let is_bf16 = if is_bf16 { 1 } else { 0 };


@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.3"
version = "0.9.0-alpha.1"
edition = "2021"
description = "CUDA kernels for Candle"


@@ -53,7 +53,7 @@ __device__ void conv1d(
template <typename T>
__device__ void im2col1d(
const size_t numel,
const size_t dst_numel,
const size_t l_out,
const size_t l_k,
const size_t stride,
@@ -63,10 +63,10 @@ __device__ void im2col1d(
const T *src,
T *dst
) {
const size_t thread_i = blockIdx.x * blockDim.x + threadIdx.x;
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (thread_i >= numel) {
if (dst_i >= dst_numel) {
return;
}
const size_t *src_dims = info;
@@ -74,26 +74,26 @@
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s1 = c_in;
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = thread_i;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i;
for (size_t l_k_idx = 0; l_k_idx < l_k; ++l_k_idx) {
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
size_t dst_i = thread_i * l_k + l_k_idx;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[dst_i] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_l_idx * src_s[2];
dst[dst_i] = src[src_i];
}
}
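
The rewrite turns im2col1d from one thread per source position writing `l_k` outputs into one thread per destination element of the `(b_size, l_out, c_in, l_k)` buffer, and the bounds check now uses the destination element count. A CPU sketch of the same indexing, assuming a contiguous source of shape `(b, c_in, l_in)` and zero padding:

```rust
fn im2col1d(
    src: &[f32],
    (b, c_in, l_in): (usize, usize, usize),
    l_k: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Vec<f32> {
    let l_out = (l_in + 2 * padding - dilation * (l_k - 1) - 1) / stride + 1;
    // Destination strides for the (b, l_out, c_in, l_k) layout.
    let (s2, s1) = (l_k, c_in * l_k);
    let s0 = l_out * s1;
    let mut dst = vec![0f32; b * s0];
    for dst_i in 0..dst.len() {
        // Decompose the flat index exactly as the kernel does.
        let (b_idx, rem) = (dst_i / s0, dst_i % s0);
        let (l_idx, rem) = (rem / s1, rem % s1);
        let (c_idx, l_k_idx) = (rem / s2, rem % s2);
        let src_l = l_idx * stride + l_k_idx * dilation;
        // Positions falling into the left or right padding stay zero.
        if src_l >= padding && src_l < l_in + padding {
            dst[dst_i] = src[(b_idx * c_in + c_idx) * l_in + src_l - padding];
        }
    }
    dst
}

fn main() {
    let src: Vec<f32> = (0..10).map(|v| v as f32).collect();
    let cols = im2col1d(&src, (1, 2, 5), 3, 1, 1, 1);
    assert_eq!(cols.len(), 5 * 2 * 3); // l_out = 5
}
```
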


@@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.3"
version = "0.9.0-alpha.1"
edition = "2021"
description = "Metal kernels for Candle"


@@ -33,7 +33,6 @@ criterion = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
cudnn = ["candle/cudnn"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]


@@ -1,6 +1,6 @@
//! Convolution Layers.
use crate::BatchNorm;
use candle::{conv::CudnnFwdAlgo, Result, Tensor};
use candle::{Result, Tensor};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Conv1dConfig {
@@ -8,7 +8,6 @@ pub struct Conv1dConfig {
pub stride: usize,
pub dilation: usize,
pub groups: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl Default for Conv1dConfig {
@@ -18,7 +17,6 @@ impl Default for Conv1dConfig {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
}
}
}
@@ -54,13 +52,12 @@ impl Conv1d {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d_with_algo(
let x = x.conv1d(
&self.weight,
self.config.padding,
self.config.stride,
self.config.dilation,
self.config.groups,
self.config.cudnn_fwd_algo,
)?;
match &self.bias {
None => Ok(x),
@@ -150,7 +147,6 @@ pub struct Conv2dConfig {
pub stride: usize,
pub dilation: usize,
pub groups: usize,
pub cudnn_fwd_algo: Option<CudnnFwdAlgo>,
}
impl Default for Conv2dConfig {
@@ -160,7 +156,6 @@ impl Default for Conv2dConfig {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
}
}
}
@@ -216,13 +211,12 @@ impl Conv2d {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d_with_algo(
let x = x.conv2d(
&self.weight,
self.config.padding,
self.config.stride,
self.config.dilation,
self.config.groups,
self.config.cudnn_fwd_algo,
)?;
match &self.bias {
None => Ok(x),
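
Dropping `cudnn_fwd_algo` reduces both configs to plain geometry, which is why every call site later in this diff loses exactly one line. Sites built with struct-update syntax would have absorbed the change silently; a sketch:

```rust
use candle_nn::Conv2dConfig;

fn main() {
    // Only the geometry fields remain; defaults fill in the rest, so a
    // field removal like this one would not have touched the call site.
    let cfg = Conv2dConfig {
        padding: 1,
        stride: 2,
        ..Default::default()
    };
    assert_eq!((cfg.groups, cfg.dilation), (1, 1));
}
```
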


@@ -31,7 +31,6 @@ pub mod ops;
pub mod optim;
pub mod rnn;
pub mod rotary_emb;
pub mod sampling;
pub mod sequential;
pub mod var_builder;
pub mod var_map;


@@ -41,36 +41,12 @@ impl Linear {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
// When possible, we avoid using a broadcasted matmul as it is much slower
// than the standard matmul for the cuda and cpu backends.
let x = match *x.dims() {
[b1, b2, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((b1 * b2 * m, k))?
.matmul(&w)?
.reshape((b1, b2, m, ()))?
} else {
let w = self.weight.broadcast_left((b1, b2))?.t()?;
x.matmul(&w)?
}
}
[bsize, m, k] => {
if x.is_contiguous() {
let w = self.weight.t()?;
x.reshape((bsize * m, k))?
.matmul(&w)?
.reshape((bsize, m, ()))?
} else {
let w = self.weight.broadcast_left(bsize)?.t()?;
x.matmul(&w)?
}
}
_ => {
let w = self.weight.t()?;
x.matmul(&w)?
}
let w = match *x.dims() {
[b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?,
[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
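
The removed branches were a fast path for contiguous inputs: flatten the batch dimensions, run one plain 2D matmul against the transposed weight, and reshape back, instead of broadcasting the weight and doing a batched matmul. A sketch of the two equivalent routes, using the shapes Linear works with:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (bsize, m, k, n) = (2, 3, 4, 5);
    let x = Tensor::randn(0f32, 1., (bsize, m, k), &dev)?;
    let w = Tensor::randn(0f32, 1., (n, k), &dev)?; // Linear stores (out, in)

    // Removed fast path: one 2D matmul on the flattened batch.
    let fast = x
        .reshape((bsize * m, k))?
        .matmul(&w.t()?)?
        .reshape((bsize, m, n))?;
    // Kept path: broadcast the weight, then a batched matmul.
    let broadcast = x.matmul(&w.broadcast_left(bsize)?.t()?)?;
    assert_eq!(fast.dims(), broadcast.dims()); // both (2, 3, 5)
    Ok(())
}
```
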


@@ -112,7 +112,7 @@ impl candle::CustomOp1 for Sigmoid {
let src = &src.slice(layout.start_offset()..);
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count)? };
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let mut builder = func.builder();
candle::builder_arg!(builder, el_count, dims.len());
@@ -373,7 +373,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
@@ -561,7 +561,7 @@ impl candle::CustomOp2 for RmsNorm {
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);
@@ -800,7 +800,7 @@ impl candle::CustomOp3 for LayerNorm {
let func =
dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&dst);


@@ -119,7 +119,7 @@ impl candle::CustomOp3 for RotaryEmbI {
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
@@ -369,7 +369,7 @@ impl candle::CustomOp3 for RotaryEmb {
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);
@@ -620,7 +620,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el)? };
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let mut builder = func.builder();
builder.arg(&src);
builder.arg(&cos);


@@ -1,20 +0,0 @@
use candle::{Result, Tensor};
/// Sample according to the Gumbel-Softmax distribution.
pub fn gumbel_softmax<D: candle::shape::Dim>(
logits: &Tensor,
temperature: f64,
dim: D,
) -> Result<Tensor> {
if temperature <= 0.0 {
logits.argmax(dim)
} else if temperature == 1.0 {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits - minus_g)?.argmax(dim)?;
Ok(sampled)
} else {
let minus_g = logits.rand_like(1e-7, 0.999)?.log()?.neg()?.log()?;
let sampled = (logits + minus_g * (-temperature))?.argmax(dim)?;
Ok(sampled)
}
}
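
For reference, the deleted helper implements the Gumbel-max trick: perturbing the logits with independent Gumbel noise and taking the argmax draws exactly from the softmax distribution,

```latex
g_i = -\log(-\log u_i), \quad u_i \sim \mathcal{U}(0,1)
\quad\Longrightarrow\quad
\arg\max_i \left( \frac{l_i}{\tau} + g_i \right)
\sim \operatorname{Categorical}\left( \operatorname{softmax}(l / \tau) \right).
```

Since the code stores `minus_g = log(-log u) = -g`, both `logits - minus_g` (the temperature 1 case) and `logits + minus_g * (-temperature)` evaluate `l + temperature * g`, whose argmax matches that of `l / temperature + g` because argmax is invariant under positive scaling. The removed test further below checks this empirically against softmax frequencies.
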


@@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0-alpha.3"
version = "0.9.0-alpha.1"
edition = "2021"
description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.3" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.3" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.1" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.1" }
prost = "0.12.1"
[build-dependencies]


@@ -29,7 +29,6 @@ tracing = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
cudnn = ["candle/cudnn", "candle-nn/cudnn"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
metal = ["candle/metal", "candle-nn/metal"]


@@ -13,8 +13,6 @@ pub enum Sampling {
TopK { k: usize, temperature: f64 },
TopP { p: f64, temperature: f64 },
TopKThenTopP { k: usize, p: f64, temperature: f64 },
// Note that the rng is not used for the Gumbel-Softmax sampling.
GumbelSoftmax { temperature: f64 },
}
pub struct LogitsProcessor {
@@ -51,11 +49,6 @@ impl LogitsProcessor {
Ok(next_token)
}
fn sample_gumbel_softmax(&mut self, logits: &Tensor, temperature: f64) -> Result<u32> {
let sampled = candle_nn::sampling::gumbel_softmax(logits, temperature, candle::D::Minus1)?;
sampled.to_vec0::<u32>()
}
fn sample_multinomial(&mut self, prs: &Vec<f32>) -> Result<u32> {
let distr = rand::distr::weighted::WeightedIndex::new(prs).map_err(Error::wrap)?;
let next_token = distr.sample(&mut self.rng) as u32;
@@ -134,9 +127,6 @@ impl LogitsProcessor {
let next_token = match &self.sampling {
Sampling::ArgMax => self.sample_argmax(logits)?,
Sampling::GumbelSoftmax { temperature } => {
self.sample_gumbel_softmax(&logits, *temperature)?
}
Sampling::All { temperature } => {
let prs = prs(*temperature)?;
self.sample_multinomial(&prs)?


@@ -124,7 +124,6 @@ impl ResidualConvUnit {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let conv1 = conv2d(
conf.num_features,
@@ -209,7 +208,6 @@ impl FeatureFusionBlock {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let output_conv = conv2d(
conf.num_features,
@@ -260,7 +258,6 @@ impl Scratch {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let layer1_rn = conv2d_no_bias(
@@ -322,7 +319,6 @@ impl Scratch {
stride: 1,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
};
let output_conv1 = conv2d(
conf.num_features,
@@ -429,7 +425,6 @@ impl DPTHead {
stride: 2,
dilation: 1,
groups: 1,
cudnn_fwd_algo: None,
},
vb.pp("resize_layers").pp("3"),
)?),


@@ -468,7 +468,6 @@ impl EncodecConv1d {
stride,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
},
vb.pp("conv"),
)?,


@@ -267,7 +267,6 @@ impl StreamableConv1d {
stride,
dilation,
groups,
cudnn_fwd_algo: None,
};
let conv = NormConv1d::new(in_c, out_c, k_size, causal, norm, bias, cfg, vb)?;
if k_size < stride {


@@ -68,7 +68,6 @@ impl ResnetBlock2D {
padding: 1,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?;
@@ -84,7 +83,6 @@
padding: 0,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
Some(conv2d(
in_channels,


@@ -248,14 +248,12 @@ impl AudioEncoder {
stride: 1,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;


@@ -244,14 +244,12 @@ impl AudioEncoder {
stride: 1,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let cfg2 = Conv1dConfig {
padding: 1,
stride: 2,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;


@@ -54,25 +54,3 @@ fn sample_with_top_k() -> Result<()> {
assert_eq!(token, 2);
Ok(())
}
#[test]
fn sample_gumbel() -> Result<()> {
let mut logits_process = LogitsProcessor::from_sampling(
42,
candle_transformers::generation::Sampling::GumbelSoftmax { temperature: 1.0 },
);
let logits = Tensor::new(&[-1.0, 0.0, 0.2, 1.0], &Device::Cpu)?;
let sm = candle_nn::ops::softmax(&logits, 0)?.to_vec1::<f64>()?;
let mut counts = vec![0f64; 4];
let samples = 100000;
for _ in 0..samples {
let token = logits_process.sample(&logits)?;
counts[token as usize] += 1f64 / samples as f64;
}
for i in 0..4 {
if (counts[i] - sm[i]).abs() > 0.05 {
panic!("pr mismatch {counts:?} {sm:?}");
}
}
Ok(())
}


@@ -98,7 +98,6 @@ impl ConvBlock {
stride,
groups: 1,
dilation: 1,
cudnn_fwd_algo: None,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;