From 665da304878326e267b178fa6e6d85424249126b Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 29 Mar 2024 23:02:11 +0100 Subject: [PATCH] Backend refactoring. (#1966) * Backend refactoring. * Metal tweaks. * Move the cudnn module. --- .../{cpu_backend.rs => cpu_backend/mod.rs} | 370 +----------- candle-core/src/cpu_backend/utils.rs | 350 ++++++++++++ candle-core/src/{ => cuda_backend}/cudnn.rs | 0 candle-core/src/cuda_backend/device.rs | 410 +++++++++++++ .../{cuda_backend.rs => cuda_backend/mod.rs} | 539 +----------------- candle-core/src/cuda_backend/utils.rs | 134 +++++ candle-core/src/lib.rs | 5 +- candle-core/src/metal_backend/device.rs | 287 ++++++++++ .../mod.rs} | 291 +--------- 9 files changed, 1202 insertions(+), 1184 deletions(-) rename candle-core/src/{cpu_backend.rs => cpu_backend/mod.rs} (87%) create mode 100644 candle-core/src/cpu_backend/utils.rs rename candle-core/src/{ => cuda_backend}/cudnn.rs (100%) create mode 100644 candle-core/src/cuda_backend/device.rs rename candle-core/src/{cuda_backend.rs => cuda_backend/mod.rs} (78%) create mode 100644 candle-core/src/cuda_backend/utils.rs create mode 100644 candle-core/src/metal_backend/device.rs rename candle-core/src/{metal_backend.rs => metal_backend/mod.rs} (86%) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend/mod.rs similarity index 87% rename from candle-core/src/cpu_backend.rs rename to candle-core/src/cpu_backend/mod.rs index 6d2ba361..d686440a 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -4,6 +4,11 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType}; use half::{bf16, f16}; use rayon::prelude::*; +mod utils; +pub use utils::{ + binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8, +}; + const USE_IM2COL_CONV1D: bool = true; const USE_IM2COL_CONV1D_TR: bool = true; const USE_IM2COL_CONV2D: bool = true; @@ -24,102 +29,6 @@ pub enum CpuStorage { #[derive(Debug, Clone)] pub struct CpuDevice; -pub trait Map1 { - fn f(&self, vs: &[T], layout: &Layout) -> Result>; - - fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { - match vs { - CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)), - CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)), - CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)), - CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)), - CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)), - CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)), - CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)), - } - } -} - -pub trait Map1Any { - fn f) -> CpuStorage>( - &self, - vs: &[T], - layout: &Layout, - wrap: W, - ) -> Result; - - fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result { - match vs { - CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?), - CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?), - CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?), - CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?), - CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?), - CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?), - CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?), - } - } -} - -type C = CpuStorage; -pub trait Map2 { - const OP: &'static str; - fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; - - fn map( - &self, - v1: &CpuStorage, - l1: &Layout, - v2: &CpuStorage, - l2: &Layout, - ) -> 
Result { - match (v1, v2) { - (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), - (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), - (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), - (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), - (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), - (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), - _ => Err(Error::DTypeMismatchBinaryOp { - lhs: v1.dtype(), - rhs: v2.dtype(), - op: Self::OP, - } - .bt()), - } - } -} - -pub trait Map2U8 { - const OP: &'static str; - fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; - - fn map( - &self, - v1: &CpuStorage, - l1: &Layout, - v2: &CpuStorage, - l2: &Layout, - ) -> Result { - match (v1, v2) { - (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), - _ => Err(Error::DTypeMismatchBinaryOp { - lhs: v1.dtype(), - rhs: v2.dtype(), - op: Self::OP, - } - .bt()), - } - } -} - struct Cmp(CmpOp); impl Map2U8 for Cmp { const OP: &'static str = "cmp"; @@ -366,275 +275,6 @@ impl<'a> Map1 for ReduceSum<'a> { } } -pub fn unary_map U>( - vs: &[T], - layout: &Layout, - mut f: F, -) -> Vec { - match layout.strided_blocks() { - crate::StridedBlocks::SingleBlock { start_offset, len } => vs - [start_offset..start_offset + len] - .iter() - .map(|&v| f(v)) - .collect(), - crate::StridedBlocks::MultipleBlocks { - block_start_index, - block_len, - } => { - let mut result = Vec::with_capacity(layout.shape().elem_count()); - // Specialize the case where block_len is one to avoid the second loop. - if block_len == 1 { - for index in block_start_index { - let v = unsafe { vs.get_unchecked(index) }; - result.push(f(*v)) - } - } else { - for index in block_start_index { - for offset in 0..block_len { - let v = unsafe { vs.get_unchecked(index + offset) }; - result.push(f(*v)) - } - } - } - result - } - } -} - -pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( - vs: &[T], - layout: &Layout, - mut f: F, - mut f_vec: FV, -) -> Vec { - match layout.strided_blocks() { - crate::StridedBlocks::SingleBlock { start_offset, len } => { - let mut ys: Vec = Vec::with_capacity(len); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; - f_vec(&vs[start_offset..start_offset + len], ys_to_set); - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(len) }; - ys - } - crate::StridedBlocks::MultipleBlocks { - block_start_index, - block_len, - } => { - let el_count = layout.shape().elem_count(); - // Specialize the case where block_len is one to avoid the second loop. 
- if block_len == 1 { - let mut result = Vec::with_capacity(el_count); - for index in block_start_index { - let v = unsafe { vs.get_unchecked(index) }; - result.push(f(*v)) - } - result - } else { - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; - let mut dst_index = 0; - for src_index in block_start_index { - let vs = &vs[src_index..src_index + block_len]; - let ys = &mut ys_to_set[dst_index..dst_index + block_len]; - f_vec(vs, ys); - dst_index += block_len; - } - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(el_count) }; - ys - } - } - } -} - -// This function maps over two strided index sequences. -pub fn binary_map U>( - lhs_l: &Layout, - rhs_l: &Layout, - lhs: &[T], - rhs: &[T], - mut f: F, -) -> Vec { - match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { - (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] - .iter() - .zip(rhs[o_r1..o_r2].iter()) - .map(|(&l, &r)| f(l, r)) - .collect(), - (Some((o_l1, o_l2)), None) => { - // TODO: Maybe we want to avoid going through the layout twice. - match rhs_l.offsets_b() { - Some(ob) => { - let mut i_in_block = 0; - let mut i_right_broadcast = 0; - lhs[o_l1..o_l2] - .iter() - .map(|&l| { - let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; - i_right_broadcast += 1; - if i_right_broadcast >= ob.right_broadcast { - i_in_block += 1; - i_right_broadcast = 0; - } - if i_in_block >= ob.len { - i_in_block = 0 - } - f(l, *r) - }) - .collect() - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } - } - (None, Some((o_r1, o_r2))) => { - // TODO: Maybe we want to avoid going through the layout twice. - match lhs_l.offsets_b() { - Some(ob) => { - let mut i_in_block = 0; - let mut i_right_broadcast = 0; - rhs[o_r1..o_r2] - .iter() - .map(|&r| { - let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; - i_right_broadcast += 1; - if i_right_broadcast >= ob.right_broadcast { - i_in_block += 1; - i_right_broadcast = 0; - } - if i_in_block >= ob.len { - i_in_block = 0 - } - f(*l, r) - }) - .collect() - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } - } - _ => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } -} - -// Similar to binary_map but with vectorized variants. -pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])>( - lhs_l: &Layout, - rhs_l: &Layout, - lhs: &[T], - rhs: &[T], - mut f: F, - mut f_vec: FV, -) -> Vec { - let el_count = lhs_l.shape().elem_count(); - match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { - (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; - f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); - // SAFETY: values are all set by f_vec. 
- unsafe { ys.set_len(el_count) }; - ys - } - (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { - Some(ob) if ob.right_broadcast == 1 => { - let rhs = &rhs[ob.start..ob.start + ob.len]; - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; - let mut dst_i = 0; - for src_i in (o_l1..o_l2).step_by(ob.len) { - f_vec( - &lhs[src_i..src_i + ob.len], - rhs, - &mut ys_to_set[dst_i..dst_i + ob.len], - ); - dst_i += ob.len; - } - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(el_count) }; - ys - } - Some(ob) => { - let rhs = &rhs[ob.start..ob.start + ob.len]; - let mut ys = lhs[o_l1..o_l2].to_vec(); - for idx_l in 0..ob.left_broadcast { - let start = idx_l * ob.len * ob.right_broadcast; - for (i, &r) in rhs.iter().enumerate() { - let start = start + i * ob.right_broadcast; - for v in ys[start..start + ob.right_broadcast].iter_mut() { - *v = f(*v, r) - } - } - } - ys - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - }, - (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { - Some(ob) if ob.right_broadcast == 1 => { - let lhs = &lhs[ob.start..ob.start + ob.len]; - let mut ys: Vec = Vec::with_capacity(el_count); - let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; - let mut dst_i = 0; - for src_i in (o_r1..o_r2).step_by(ob.len) { - f_vec( - lhs, - &rhs[src_i..src_i + ob.len], - &mut ys_to_set[dst_i..dst_i + ob.len], - ); - dst_i += ob.len; - } - // SAFETY: values are all set by f_vec. - unsafe { ys.set_len(el_count) }; - ys - } - Some(ob) => { - let lhs = &lhs[ob.start..ob.start + ob.len]; - let mut ys = rhs[o_r1..o_r2].to_vec(); - for idx_l in 0..ob.left_broadcast { - let start = idx_l * ob.len * ob.right_broadcast; - for (i, &l) in lhs.iter().enumerate() { - let start = start + i * ob.right_broadcast; - for v in ys[start..start + ob.right_broadcast].iter_mut() { - *v = f(l, *v) - } - } - } - ys - } - None => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - }, - _ => lhs_l - .strided_index() - .zip(rhs_l.strided_index()) - .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect(), - } -} - struct Affine(f64, f64); impl Map1 for Affine { diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs new file mode 100644 index 00000000..af25a2af --- /dev/null +++ b/candle-core/src/cpu_backend/utils.rs @@ -0,0 +1,350 @@ +/// Helper functions to write CPU kernels. 
+use crate::backend::BackendStorage; +use crate::{Error, Layout, Result, WithDType}; + +type C = super::CpuStorage; +pub trait Map1 { + fn f(&self, vs: &[T], layout: &Layout) -> Result>; + + fn map(&self, vs: &C, layout: &Layout) -> Result { + match vs { + C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), + C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), + C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), + C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), + C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)), + C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)), + } + } +} + +pub trait Map1Any { + fn f) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result; + + fn map(&self, vs: &C, layout: &Layout) -> Result { + match vs { + C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), + C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), + C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), + C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), + C::F32(vs) => Ok(self.f(vs, layout, C::F32)?), + C::F64(vs) => Ok(self.f(vs, layout, C::F64)?), + } + } +} + +pub trait Map2 { + const OP: &'static str; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; + + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub trait Map2U8 { + const OP: &'static str; + fn f(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result>; + + fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result { + match (v1, v2) { + (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)), + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: v1.dtype(), + rhs: v2.dtype(), + op: Self::OP, + } + .bt()), + } + } +} + +pub fn binary_map U>( + lhs_l: &Layout, + rhs_l: &Layout, + lhs: &[T], + rhs: &[T], + mut f: F, +) -> Vec { + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] + .iter() + .zip(rhs[o_r1..o_r2].iter()) + .map(|(&l, &r)| f(l, r)) + .collect(), + (Some((o_l1, o_l2)), None) => { + // TODO: Maybe we want to avoid going through the layout twice. 
+ match rhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + lhs[o_l1..o_l2] + .iter() + .map(|&l| { + let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(l, *r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + (None, Some((o_r1, o_r2))) => { + // TODO: Maybe we want to avoid going through the layout twice. + match lhs_l.offsets_b() { + Some(ob) => { + let mut i_in_block = 0; + let mut i_right_broadcast = 0; + rhs[o_r1..o_r2] + .iter() + .map(|&r| { + let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) }; + i_right_broadcast += 1; + if i_right_broadcast >= ob.right_broadcast { + i_in_block += 1; + i_right_broadcast = 0; + } + if i_in_block >= ob.len { + i_in_block = 0 + } + f(*l, r) + }) + .collect() + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } + } + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } +} + +// Similar to binary_map but with vectorized variants. +pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])>( + lhs_l: &Layout, + rhs_l: &Layout, + lhs: &[T], + rhs: &[T], + mut f: F, + mut f_vec: FV, +) -> Vec { + let el_count = lhs_l.shape().elem_count(); + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { + Some(ob) if ob.right_broadcast == 1 => { + let rhs = &rhs[ob.start..ob.start + ob.len]; + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let mut dst_i = 0; + for src_i in (o_l1..o_l2).step_by(ob.len) { + f_vec( + &lhs[src_i..src_i + ob.len], + rhs, + &mut ys_to_set[dst_i..dst_i + ob.len], + ); + dst_i += ob.len; + } + // SAFETY: values are all set by f_vec. 
+ unsafe { ys.set_len(el_count) }; + ys + } + Some(ob) => { + let rhs = &rhs[ob.start..ob.start + ob.len]; + let mut ys = lhs[o_l1..o_l2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &r) in rhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(*v, r) + } + } + } + ys + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + }, + (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { + Some(ob) if ob.right_broadcast == 1 => { + let lhs = &lhs[ob.start..ob.start + ob.len]; + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let mut dst_i = 0; + for src_i in (o_r1..o_r2).step_by(ob.len) { + f_vec( + lhs, + &rhs[src_i..src_i + ob.len], + &mut ys_to_set[dst_i..dst_i + ob.len], + ); + dst_i += ob.len; + } + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + Some(ob) => { + let lhs = &lhs[ob.start..ob.start + ob.len]; + let mut ys = rhs[o_r1..o_r2].to_vec(); + for idx_l in 0..ob.left_broadcast { + let start = idx_l * ob.len * ob.right_broadcast; + for (i, &l) in lhs.iter().enumerate() { + let start = start + i * ob.right_broadcast; + for v in ys[start..start + ob.right_broadcast].iter_mut() { + *v = f(l, *v) + } + } + } + ys + } + None => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + }, + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) + .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) + .collect(), + } +} + +pub fn unary_map U>( + vs: &[T], + layout: &Layout, + mut f: F, +) -> Vec { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => vs + [start_offset..start_offset + len] + .iter() + .map(|&v| f(v)) + .collect(), + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let mut result = Vec::with_capacity(layout.shape().elem_count()); + // Specialize the case where block_len is one to avoid the second loop. + if block_len == 1 { + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + } else { + for index in block_start_index { + for offset in 0..block_len { + let v = unsafe { vs.get_unchecked(index + offset) }; + result.push(f(*v)) + } + } + } + result + } + } +} + +pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U])>( + vs: &[T], + layout: &Layout, + mut f: F, + mut f_vec: FV, +) -> Vec { + match layout.strided_blocks() { + crate::StridedBlocks::SingleBlock { start_offset, len } => { + let mut ys: Vec = Vec::with_capacity(len); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + f_vec(&vs[start_offset..start_offset + len], ys_to_set); + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(len) }; + ys + } + crate::StridedBlocks::MultipleBlocks { + block_start_index, + block_len, + } => { + let el_count = layout.shape().elem_count(); + // Specialize the case where block_len is one to avoid the second loop. 
+ if block_len == 1 { + let mut result = Vec::with_capacity(el_count); + for index in block_start_index { + let v = unsafe { vs.get_unchecked(index) }; + result.push(f(*v)) + } + result + } else { + let mut ys: Vec = Vec::with_capacity(el_count); + let ys_to_set = ys.spare_capacity_mut(); + let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + let mut dst_index = 0; + for src_index in block_start_index { + let vs = &vs[src_index..src_index + block_len]; + let ys = &mut ys_to_set[dst_index..dst_index + block_len]; + f_vec(vs, ys); + dst_index += block_len; + } + // SAFETY: values are all set by f_vec. + unsafe { ys.set_len(el_count) }; + ys + } + } + } +} diff --git a/candle-core/src/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs similarity index 100% rename from candle-core/src/cudnn.rs rename to candle-core/src/cuda_backend/cudnn.rs diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs new file mode 100644 index 00000000..0859d756 --- /dev/null +++ b/candle-core/src/cuda_backend/device.rs @@ -0,0 +1,410 @@ +use crate::backend::BackendDevice; +use crate::{CpuStorage, DType, Layout, Result, Shape}; +pub use candle_kernels as kernels; +pub use cudarc; +use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig}; +use half::{bf16, f16}; +use std::sync::{Arc, Mutex}; + +use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr}; + +/// Unique identifier for cuda devices. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +struct CudaRng(cudarc::curand::CudaRng); +unsafe impl Send for CudaRng {} + +#[derive(Clone)] +pub struct CudaDevice { + id: DeviceId, + device: Arc, + pub(crate) blas: Arc, + curand: Arc>, +} + +impl std::fmt::Debug for CudaDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CudaDevice({:?})", self.id) + } +} + +impl std::ops::Deref for CudaDevice { + type Target = Arc; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl CudaDevice { + pub fn cuda_device(&self) -> Arc { + self.device.clone() + } + + pub fn id(&self) -> DeviceId { + self.id + } + + fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let cfg = LaunchConfig::for_num_elems(elem_count as u32); + let slice = match dtype { + DType::U8 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_u8", kernels::FILL)?; + let params = (&data, v as u8, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_u32", kernels::FILL)?; + let params = (&data, v as u32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + // SAFETY: Set later by running the fill kernel. 
+ let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i64", kernels::FILL)?; + let params = (&data, v as i64, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; + let params = (&data, bf16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f16", kernels::FILL)?; + let params = (&data, f16::from_f64(v), elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f32", kernels::FILL)?; + let params = (&data, v as f32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_f64", kernels::FILL)?; + let params = (&data, v, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result { + if !self.has_func(module_name, module_name) { + // Leaking the string here is a bit sad but we need a &'static str and this is only + // done once per kernel name. + let static_module_name = Box::leak(module_name.to_string().into_boxed_str()); + self.load_ptx(ptx.into(), module_name, &[static_module_name]) + .map_err(|cuda| CudaError::Load { + cuda, + module_name: module_name.to_string(), + }) + .w()?; + } + self.get_func(module_name, module_name) + // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is + // able to only build the error value if needed. + .ok_or(CudaError::MissingKernel { + module_name: module_name.to_string(), + }) + .w() + } +} + +impl BackendDevice for CudaDevice { + type Storage = CudaStorage; + + fn new(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new(ordinal).w()?; + let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; + let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; + Ok(Self { + id: DeviceId::new(), + device, + blas: Arc::new(blas), + curand: Arc::new(Mutex::new(CudaRng(curand))), + }) + } + + fn set_seed(&self, seed: u64) -> Result<()> { + // We do not call set_seed but instead create a new curand object. This ensures that the + // state will be identical and the same random numbers will be generated. 
+ let mut curand = self.curand.lock().unwrap(); + curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; + Ok(()) + } + + fn location(&self) -> crate::DeviceLocation { + crate::DeviceLocation::Cuda { + gpu_id: self.device.ordinal(), + } + } + + fn same_device(&self, rhs: &Self) -> bool { + self.id == rhs.id + } + + fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result { + let elem_count = shape.elem_count(); + let curand = self.curand.lock().unwrap(); + let slice = match dtype { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_uniform", + }) + .w()? + } + DType::F32 => { + let mut data = unsafe { self.alloc::(elem_count) }.w()?; + curand.0.fill_with_uniform(&mut data).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let mut data = unsafe { self.alloc::(elem_count) }.w()?; + curand.0.fill_with_uniform(&mut data).w()?; + CudaStorageSlice::F64(data) + } + }; + let slice = if lo == 0. && up == 1.0 { + slice + } else { + use super::utils::Map1; + let layout = Layout::contiguous(shape); + super::Affine(up - lo, lo).map(&slice, self, &layout)? + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result { + // TODO: Add support for F16 and BF16 though this is likely to require some upstream + // cudarc changes. + let elem_count = shape.elem_count(); + let curand = self.curand.lock().unwrap(); + // curand can only generate an odd number of values. + // https://github.com/huggingface/candle/issues/734 + let elem_count_round = if elem_count % 2 == 1 { + elem_count + 1 + } else { + elem_count + }; + let slice = match dtype { + DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + Err(CudaError::UnsupportedDtype { + dtype, + op: "rand_normal", + }) + .w()? 
+ } + DType::F32 => { + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; + curand + .0 + .fill_with_normal(&mut data, mean as f32, std as f32) + .w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; + curand.0.fill_with_normal(&mut data, mean, std).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { + self.const_impl(1., shape, dtype) + } + + unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { + let elem_count = shape.elem_count(); + let slice = match dtype { + DType::U8 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U8(data) + } + DType::U32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::U32(data) + } + DType::I64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I64(data) + } + DType::BF16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::BF16(data) + } + DType::F16 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F16(data) + } + DType::F32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F32(data) + } + DType::F64 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } + + fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result { + let slice = match storage { + CpuStorage::U8(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U8(data) + } + CpuStorage::U32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::U32(data) + } + CpuStorage::I64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I64(data) + } + CpuStorage::BF16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::BF16(data) + } + CpuStorage::F16(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F16(data) + } + CpuStorage::F32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F32(data) + } + CpuStorage::F64(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::F64(data) + } + }; + Ok(CudaStorage { + slice, + device: self.clone(), + }) + } +} diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend/mod.rs similarity index 78% rename from candle-core/src/cuda_backend.rs rename to candle-core/src/cuda_backend/mod.rs index 23487330..78aebd9b 
100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -5,11 +5,17 @@ pub use candle_kernels as kernels; pub use cudarc; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{ - CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, - ValidAsZeroBits, + CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits, }; use half::{bf16, f16}; -use std::sync::{Arc, Mutex}; + +mod device; +pub use device::{CudaDevice, DeviceId}; +mod utils; +pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S}; + +#[cfg(feature = "cudnn")] +pub mod cudnn; enum SlicePtrOrNull { Ptr(CudaSlice), @@ -87,44 +93,6 @@ impl From for crate::Error { } } -/// Unique identifier for cuda devices. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct DeviceId(usize); - -impl DeviceId { - fn new() -> Self { - // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 - use std::sync::atomic; - static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); - Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) - } -} - -struct CudaRng(cudarc::curand::CudaRng); -unsafe impl Send for CudaRng {} - -#[derive(Clone)] -pub struct CudaDevice { - id: DeviceId, - device: Arc, - blas: Arc, - curand: Arc>, -} - -impl std::fmt::Debug for CudaDevice { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "CudaDevice({:?})", self.id) - } -} - -impl std::ops::Deref for CudaDevice { - type Target = Arc; - - fn deref(&self) -> &Self::Target { - &self.device - } -} - pub trait WrapErr { fn w(self) -> std::result::Result; } @@ -135,368 +103,6 @@ impl> WrapErr for std::result::Result { } } -impl CudaDevice { - pub fn cuda_device(&self) -> Arc { - self.device.clone() - } - - pub fn id(&self) -> DeviceId { - self.id - } - - fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result { - let elem_count = shape.elem_count(); - let cfg = LaunchConfig::for_num_elems(elem_count as u32); - let slice = match dtype { - DType::U8 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_u8", kernels::FILL)?; - let params = (&data, v as u8, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::U8(data) - } - DType::U32 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_u32", kernels::FILL)?; - let params = (&data, v as u32, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::U32(data) - } - DType::I64 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_i64", kernels::FILL)?; - let params = (&data, v as i64, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::I64(data) - } - DType::BF16 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_bf16", kernels::FILL)?; - let params = (&data, bf16::from_f64(v), elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::BF16(data) - } - DType::F16 => { - // SAFETY: Set later by running the fill kernel. 
- let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_f16", kernels::FILL)?; - let params = (&data, f16::from_f64(v), elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F16(data) - } - DType::F32 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_f32", kernels::FILL)?; - let params = (&data, v as f32, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - // SAFETY: Set later by running the fill kernel. - let data = unsafe { self.alloc::(elem_count) }.w()?; - let func = self.get_or_load_func("fill_f64", kernels::FILL)?; - let params = (&data, v, elem_count); - unsafe { func.launch(cfg, params) }.w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result { - if !self.has_func(module_name, module_name) { - // Leaking the string here is a bit sad but we need a &'static str and this is only - // done once per kernel name. - let static_module_name = Box::leak(module_name.to_string().into_boxed_str()); - self.load_ptx(ptx.into(), module_name, &[static_module_name]) - .map_err(|cuda| CudaError::Load { - cuda, - module_name: module_name.to_string(), - }) - .w()?; - } - self.get_func(module_name, module_name) - // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is - // able to only build the error value if needed. - .ok_or(CudaError::MissingKernel { - module_name: module_name.to_string(), - }) - .w() - } -} - -impl BackendDevice for CudaDevice { - type Storage = CudaStorage; - - fn new(ordinal: usize) -> Result { - let device = cudarc::driver::CudaDevice::new(ordinal).w()?; - let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; - let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; - Ok(Self { - id: DeviceId::new(), - device, - blas: Arc::new(blas), - curand: Arc::new(Mutex::new(CudaRng(curand))), - }) - } - - fn set_seed(&self, seed: u64) -> Result<()> { - // We do not call set_seed but instead create a new curand object. This ensures that the - // state will be identical and the same random numbers will be generated. 
- let mut curand = self.curand.lock().unwrap(); - curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?; - Ok(()) - } - - fn location(&self) -> crate::DeviceLocation { - crate::DeviceLocation::Cuda { - gpu_id: self.device.ordinal(), - } - } - - fn same_device(&self, rhs: &Self) -> bool { - self.id == rhs.id - } - - fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result { - let elem_count = shape.elem_count(); - let slice = match dtype { - DType::U8 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::U8(data) - } - DType::U32 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::U32(data) - } - DType::I64 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::I64(data) - } - DType::BF16 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::BF16(data) - } - DType::F16 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::F16(data) - } - DType::F32 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let data = self.alloc_zeros::(elem_count).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result { - let elem_count = shape.elem_count(); - let curand = self.curand.lock().unwrap(); - let slice = match dtype { - // TODO: Add support for F16 and BF16 though this is likely to require some upstream - // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_uniform", - }) - .w()? - } - DType::F32 => { - let mut data = unsafe { self.alloc::(elem_count) }.w()?; - curand.0.fill_with_uniform(&mut data).w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let mut data = unsafe { self.alloc::(elem_count) }.w()?; - curand.0.fill_with_uniform(&mut data).w()?; - CudaStorageSlice::F64(data) - } - }; - let slice = if lo == 0. && up == 1.0 { - slice - } else { - let layout = Layout::contiguous(shape); - Affine(up - lo, lo).map(&slice, self, &layout)? - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result { - // TODO: Add support for F16 and BF16 though this is likely to require some upstream - // cudarc changes. - let elem_count = shape.elem_count(); - let curand = self.curand.lock().unwrap(); - // curand can only generate an odd number of values. - // https://github.com/huggingface/candle/issues/734 - let elem_count_round = if elem_count % 2 == 1 { - elem_count + 1 - } else { - elem_count - }; - let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { - Err(CudaError::UnsupportedDtype { - dtype, - op: "rand_normal", - }) - .w()? 
- } - DType::F32 => { - let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; - curand - .0 - .fill_with_normal(&mut data, mean as f32, std as f32) - .w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let mut data = unsafe { self.alloc::(elem_count_round) }.w()?; - curand.0.fill_with_normal(&mut data, mean, std).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result { - self.const_impl(1., shape, dtype) - } - - unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result { - let elem_count = shape.elem_count(); - let slice = match dtype { - DType::U8 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::U8(data) - } - DType::U32 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::U32(data) - } - DType::I64 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::I64(data) - } - DType::BF16 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::BF16(data) - } - DType::F16 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::F16(data) - } - DType::F32 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::F32(data) - } - DType::F64 => { - let data = self.alloc::(elem_count).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result { - let slice = match storage { - CpuStorage::U8(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::U8(data) - } - CpuStorage::U32(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::U32(data) - } - CpuStorage::I64(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::I64(data) - } - CpuStorage::BF16(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::BF16(data) - } - CpuStorage::F16(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::F16(data) - } - CpuStorage::F32(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::F32(data) - } - CpuStorage::F64(storage) => { - let data = self.htod_sync_copy(storage).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } - - fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result { - let slice = match storage { - CpuStorage::U8(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::U8(data) - } - CpuStorage::U32(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::U32(data) - } - CpuStorage::I64(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::I64(data) - } - CpuStorage::BF16(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::BF16(data) - } - CpuStorage::F16(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::F16(data) - } - CpuStorage::F32(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::F32(data) - } - CpuStorage::F64(storage) => { - let data = self.htod_copy(storage).w()?; - CudaStorageSlice::F64(data) - } - }; - Ok(CudaStorage { - slice, - device: self.clone(), - }) - } -} - #[derive(Debug)] pub enum CudaStorageSlice { U8(CudaSlice), @@ -507,133 +113,6 @@ pub enum CudaStorageSlice { F32(CudaSlice), F64(CudaSlice), } -type S = CudaStorageSlice; - -pub trait Map1 { - fn f( - &self, - src: 
&CudaSlice, - dev: &CudaDevice, - layout: &Layout, - ) -> Result>; - - fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { - let out = match s { - S::U8(s) => S::U8(self.f(s, d, l)?), - S::U32(s) => S::U32(self.f(s, d, l)?), - S::I64(s) => S::I64(self.f(s, d, l)?), - S::BF16(s) => S::BF16(self.f(s, d, l)?), - S::F16(s) => S::F16(self.f(s, d, l)?), - S::F32(s) => S::F32(self.f(s, d, l)?), - S::F64(s) => S::F64(self.f(s, d, l)?), - }; - Ok(out) - } -} - -pub trait Map2 { - fn f( - &self, - src1: &CudaSlice, - layout1: &Layout, - src2: &CudaSlice, - layout2: &Layout, - dev: &CudaDevice, - ) -> Result>; - - fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { - let out = match (s1, s2) { - (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), - (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), - (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), - (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), - (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), - (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), - (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), - _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, - }; - Ok(out) - } -} - -pub trait Map2InPlace { - fn f( - &self, - dst: &mut CudaSlice, - dst_shape: &Shape, - src: &CudaSlice, - src_l: &Layout, - dev: &CudaDevice, - ) -> Result<()>; - - fn map( - &self, - dst: &mut S, - dst_s: &Shape, - src: &S, - src_l: &Layout, - d: &CudaDevice, - ) -> Result<()> { - match (dst, src) { - (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d), - (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d), - (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d), - (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), - (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), - _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, - } - } -} - -pub trait Map1Any { - fn f) -> S>( - &self, - src: &CudaSlice, - dev: &CudaDevice, - layout: &Layout, - wrap: W, - ) -> Result; - - fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { - let out = match s { - S::U8(s) => self.f(s, d, l, S::U8)?, - S::U32(s) => self.f(s, d, l, S::U32)?, - S::I64(s) => self.f(s, d, l, S::I64)?, - S::BF16(s) => self.f(s, d, l, S::BF16)?, - S::F16(s) => self.f(s, d, l, S::F16)?, - S::F32(s) => self.f(s, d, l, S::F32)?, - S::F64(s) => self.f(s, d, l, S::F64)?, - }; - Ok(out) - } -} - -pub trait Map2Any { - fn f( - &self, - src1: &CudaSlice, - layout1: &Layout, - src2: &CudaSlice, - layout2: &Layout, - dev: &CudaDevice, - ) -> Result; - - fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { - let out = match (s1, s2) { - (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, - (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, - _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, - }; - Ok(out) - } -} struct Clone; impl Map1 for Clone { diff --git a/candle-core/src/cuda_backend/utils.rs 
b/candle-core/src/cuda_backend/utils.rs new file mode 100644 index 00000000..8dd5be77 --- /dev/null +++ b/candle-core/src/cuda_backend/utils.rs @@ -0,0 +1,134 @@ +/// Helper functions to plug cuda kernels in candle. +use crate::{Layout, Result, Shape, WithDType}; +pub use cudarc; +use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits}; + +use super::{CudaDevice, CudaError, WrapErr}; + +pub type S = super::CudaStorageSlice; + +pub trait Map1 { + fn f( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + ) -> Result>; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U8(s) => S::U8(self.f(s, d, l)?), + S::U32(s) => S::U32(self.f(s, d, l)?), + S::I64(s) => S::I64(self.f(s, d, l)?), + S::BF16(s) => S::BF16(self.f(s, d, l)?), + S::F16(s) => S::F16(self.f(s, d, l)?), + S::F32(s) => S::F32(self.f(s, d, l)?), + S::F64(s) => S::F64(self.f(s, d, l)?), + }; + Ok(out) + } +} + +pub trait Map2 { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result>; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?), + (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?), + (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?), + (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?), + (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?), + (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?), + (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, + }; + Ok(out) + } +} + +pub trait Map2InPlace { + fn f( + &self, + dst: &mut CudaSlice, + dst_shape: &Shape, + src: &CudaSlice, + src_l: &Layout, + dev: &CudaDevice, + ) -> Result<()>; + + fn map( + &self, + dst: &mut S, + dst_s: &Shape, + src: &S, + src_l: &Layout, + d: &CudaDevice, + ) -> Result<()> { + match (dst, src) { + (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d), + (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d), + (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d), + (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d), + (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, + } + } +} + +pub trait Map1Any { + fn f) -> S>( + &self, + src: &CudaSlice, + dev: &CudaDevice, + layout: &Layout, + wrap: W, + ) -> Result; + + fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result { + let out = match s { + S::U8(s) => self.f(s, d, l, S::U8)?, + S::U32(s) => self.f(s, d, l, S::U32)?, + S::I64(s) => self.f(s, d, l, S::I64)?, + S::BF16(s) => self.f(s, d, l, S::BF16)?, + S::F16(s) => self.f(s, d, l, S::F16)?, + S::F32(s) => self.f(s, d, l, S::F32)?, + S::F64(s) => self.f(s, d, l, S::F64)?, + }; + Ok(out) + } +} + +pub trait Map2Any { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + dev: &CudaDevice, + ) -> Result; + + fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result { + let out = match (s1, s2) { + (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::I64(s1), 
S::I64(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, + (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, + _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?, + }; + Ok(out) + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 911e379f..862436ab 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -43,8 +43,6 @@ pub mod cpu; pub mod cpu_backend; #[cfg(feature = "cuda")] pub mod cuda_backend; -#[cfg(feature = "cudnn")] -pub mod cudnn; mod custom_op; mod device; pub mod display; @@ -73,6 +71,9 @@ pub mod test_utils; pub mod utils; mod variable; +#[cfg(feature = "cudnn")] +pub use cuda_backend::cudnn; + pub use cpu_backend::CpuStorage; pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; pub use device::{Device, DeviceLocation, NdArray}; diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs new file mode 100644 index 00000000..fdeca13f --- /dev/null +++ b/candle-core/src/metal_backend/device.rs @@ -0,0 +1,287 @@ +use crate::{DType, Result}; +use candle_metal_kernels::Kernels; +use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use std::collections::HashMap; +use std::ffi::c_void; +use std::path::Path; +use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard}; + +use super::MetalError; + +/// Unique identifier for cuda devices. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DeviceId(usize); + +impl DeviceId { + pub(crate) fn new() -> Self { + // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 + use std::sync::atomic; + static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); + Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) + } +} + +type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec>>; +type AllocatedBuffers = Arc>; + +#[derive(Clone)] +pub struct MetalDevice { + /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than + /// the device itself. + pub(crate) id: DeviceId, + + /// Raw metal device: + pub(crate) device: metal::Device, + + /// Single command queue for the entire device. + pub(crate) command_queue: CommandQueue, + /// One command buffer at a time. + /// The scheduler works by allowing multiple + /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) + /// on a single command buffer. Using a single command buffer would be fastest on the GPU but + /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed + /// to start to work). + /// Despite what the documentation says, command buffers are NOT ordered. They are ordered + /// for their START time, but there's no guarantee that command buffer1 will finish before + /// command buffer2 starts (or there are metal bugs there) + pub(crate) command_buffer: Arc>, + /// Keeps track of the current amount of compute command encoders on the current + /// command buffer + /// Arc, RwLock because of the interior mutability. 
+ pub(crate) command_buffer_index: Arc>, + /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) + pub(crate) compute_per_buffer: usize, + /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. + /// Heavily used by [`candle_metal_kernels`] + pub(crate) kernels: Arc, + /// Simple allocator struct. + /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. + /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting + /// (could be linked to FFI communication overhead). + /// + /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the + /// graph calculation, and only we the allocator kept a reference to it, therefore it's free + /// to be reused. However, in order for this to work, we need to guarantee the order of + /// operation, so that this buffer is not being used by another kernel at the same time. + /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. + /// + /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers + /// (strong_count = 1). + pub(crate) buffers: AllocatedBuffers, + /// Seed for random number generation. + pub(crate) seed: Arc>, +} + +impl std::fmt::Debug for MetalDevice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MetalDevice({:?})", self.id) + } +} + +impl std::ops::Deref for MetalDevice { + type Target = metal::DeviceRef; + + fn deref(&self) -> &Self::Target { + &self.device + } +} + +impl MetalDevice { + pub fn id(&self) -> DeviceId { + self.id + } + + pub fn metal_device(&self) -> &metal::Device { + &self.device + } + + pub fn command_queue(&self) -> &CommandQueue { + &self.command_queue + } + + pub fn command_buffer(&self) -> Result { + let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; + let mut command_buffer = command_buffer_lock.to_owned(); + let mut index = self + .command_buffer_index + .try_write() + .map_err(MetalError::from)?; + if *index > self.compute_per_buffer { + command_buffer.commit(); + command_buffer = self.command_queue.new_command_buffer().to_owned(); + *command_buffer_lock = command_buffer.clone(); + *index = 0; + + self.drop_unused_buffers()?; + } + *index += 1; + Ok(command_buffer) + } + + pub fn wait_until_completed(&self) -> Result<()> { + let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; + match command_buffer.status() { + metal::MTLCommandBufferStatus::Committed + | metal::MTLCommandBufferStatus::Scheduled + | metal::MTLCommandBufferStatus::Completed => { + panic!("Already committed"); + } + _ => {} + } + command_buffer.commit(); + command_buffer.wait_until_completed(); + *command_buffer = self.command_queue.new_command_buffer().to_owned(); + + Ok(()) + } + + pub fn kernels(&self) -> &Kernels { + &self.kernels + } + + pub fn device(&self) -> &metal::Device { + &self.device + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer data cannot be read on the CPU directly. 
+ /// + /// [`name`] is only used to keep track of the resource origin in case of bugs + pub fn new_buffer( + &self, + element_count: usize, + dtype: DType, + name: &str, + ) -> Result> { + let size = (element_count * dtype.size_in_bytes()) as NSUInteger; + self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) + } + + /// Creates a new buffer (not necessarily zeroed). + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// This means the buffer can be read on the CPU but will require manual + /// synchronization when the CPU memory is modified + /// Used as a bridge to gather data back from the GPU + pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { + self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") + } + + /// Creates a new buffer from data. + /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) + /// + /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) + /// allocates the buffer and copies over the existing data before returning the MTLBuffer. + pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { + let size = core::mem::size_of_val(data) as NSUInteger; + let new_buffer = self.device.new_buffer_with_data( + data.as_ptr() as *const c_void, + size, + MTLResourceOptions::StorageModeManaged, + ); + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + let subbuffers = buffers + .entry((size, MTLResourceOptions::StorageModeManaged)) + .or_insert(vec![]); + + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + Ok(new_buffer) + } + + pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { + let buffer = self.allocate_buffer( + size_in_bytes as NSUInteger, + MTLResourceOptions::StorageModePrivate, + "allocate_zeros", + )?; + let command_buffer = self.command_buffer()?; + command_buffer.set_label("zeros"); + let blit = command_buffer.new_blit_command_encoder(); + blit.fill_buffer( + &buffer, + metal::NSRange { + location: 0, + length: buffer.length(), + }, + 0, + ); + blit.end_encoding(); + Ok(buffer) + } + + fn find_available_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + buffers: &RwLockWriteGuard, + ) -> Option> { + let mut best_buffer: Option<&Arc> = None; + let mut best_buffer_size: NSUInteger = NSUInteger::MAX; + for ((buffer_size, buffer_option), subbuffers) in buffers.iter() { + if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option { + for sub in subbuffers { + if Arc::strong_count(sub) == 1 { + best_buffer = Some(sub); + best_buffer_size = *buffer_size; + } + } + } + } + best_buffer.cloned() + } + + fn drop_unused_buffers(&self) -> Result<()> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + for subbuffers in buffers.values_mut() { + let newbuffers = subbuffers + .iter() + .filter(|s| Arc::strong_count(*s) > 1) + .map(Arc::clone) + .collect(); + *subbuffers = newbuffers; + } + Ok(()) + } + + /// The critical allocator algorithm + fn allocate_buffer( + &self, + size: NSUInteger, + option: MTLResourceOptions, + _name: &str, + ) -> Result> { + let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; + if let Some(b) = self.find_available_buffer(size, option, &buffers) { + // Cloning also ensures we increment the strong count + return Ok(b.clone()); + } + + let size = buf_size(size); + let 
subbuffers = buffers.entry((size, option)).or_insert(vec![]); + + let new_buffer = self.device.new_buffer(size as NSUInteger, option); + let new_buffer = Arc::new(new_buffer); + subbuffers.push(new_buffer.clone()); + + Ok(new_buffer) + } + + /// Create a metal GPU capture trace on [`path`]. + pub fn capture>(&self, path: P) -> Result<()> { + let capture = metal::CaptureManager::shared(); + let descriptor = metal::CaptureDescriptor::new(); + descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); + descriptor.set_capture_device(self); + descriptor.set_output_url(path); + + capture + .start_capture(&descriptor) + .map_err(MetalError::from)?; + Ok(()) + } +} + +fn buf_size(size: NSUInteger) -> NSUInteger { + (size - 1).next_power_of_two() as NSUInteger +} diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend/mod.rs similarity index 86% rename from candle-core/src/metal_backend.rs rename to candle-core/src/metal_backend/mod.rs index fed7db13..deb7a401 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -4,24 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels::CallConvTranspose2dCfg; use candle_metal_kernels::Kernels; -use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; +use metal::{Buffer, MTLResourceOptions, NSUInteger}; use std::collections::HashMap; use std::ffi::c_void; -use std::path::Path; -use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError}; +use std::sync::{Arc, Mutex, RwLock, TryLockError}; -/// Unique identifier for cuda devices. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct DeviceId(usize); - -impl DeviceId { - fn new() -> Self { - // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805 - use std::sync::atomic; - static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1); - Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed)) - } -} +mod device; +pub use device::{DeviceId, MetalDevice}; /// Simple way to catch lock error without /// depending on T @@ -49,13 +38,6 @@ pub enum MetalError { Message(String), #[error(transparent)] KernelError(#[from] candle_metal_kernels::MetalKernelError), - - #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] - MatMulNonContiguous { - lhs_stride: Vec, - rhs_stride: Vec, - mnk: (usize, usize, usize), - }, #[error("{0:?}")] LockError(LockError), #[error("{msg}, expected: {expected:?}, got: {got:?}")] @@ -72,267 +54,6 @@ impl From for MetalError { } } -type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec>>; -type AllocatedBuffers = Arc>; - -#[derive(Clone)] -pub struct MetalDevice { - /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than - /// the device itself. - id: DeviceId, - - /// Raw metal device: - device: metal::Device, - - /// Single command queue for the entire device. - command_queue: CommandQueue, - /// One command buffer at a time. - /// The scheduler works by allowing multiple - /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) - /// on a single command buffer. Using a single command buffer would be fastest on the GPU but - /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed - /// to start to work). 
- /// Despite what the documentation says, command buffers are NOT ordered. They are ordered - /// for their START time, but there's no guarantee that command buffer1 will finish before - /// command buffer2 starts (or there are metal bugs there) - command_buffer: Arc>, - /// Keeps track of the current amount of compute command encoders on the current - /// command buffer - /// Arc, RwLock because of the interior mutability. - command_buffer_index: Arc>, - /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc) - compute_per_buffer: usize, - /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them. - /// Heavily used by [`candle_metal_kernels`] - kernels: Arc, - /// Simple allocator struct. - /// The buffers are stored in size buckets since ML tends to use similar shapes over and over. - /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting - /// (could be linked to FFI communication overhead). - /// - /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the - /// graph calculation, and only we the allocator kept a reference to it, therefore it's free - /// to be reused. However, in order for this to work, we need to guarantee the order of - /// operation, so that this buffer is not being used by another kernel at the same time. - /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things. - /// - /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers - /// (strong_count = 1). - buffers: AllocatedBuffers, - /// Seed for random number generation. 
- seed: Arc>, -} - -impl std::fmt::Debug for MetalDevice { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "MetalDevice({:?})", self.id) - } -} - -impl std::ops::Deref for MetalDevice { - type Target = metal::DeviceRef; - - fn deref(&self) -> &Self::Target { - &self.device - } -} - -impl MetalDevice { - pub fn id(&self) -> DeviceId { - self.id - } - - pub fn metal_device(&self) -> &metal::Device { - &self.device - } - - pub fn command_queue(&self) -> &CommandQueue { - &self.command_queue - } - - pub fn command_buffer(&self) -> Result { - let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?; - let mut command_buffer = command_buffer_lock.to_owned(); - let mut index = self - .command_buffer_index - .try_write() - .map_err(MetalError::from)?; - if *index > self.compute_per_buffer { - command_buffer.commit(); - command_buffer = self.command_queue.new_command_buffer().to_owned(); - *command_buffer_lock = command_buffer.clone(); - *index = 0; - - self.drop_unused_buffers()?; - } - *index += 1; - Ok(command_buffer) - } - - pub fn wait_until_completed(&self) -> Result<()> { - let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?; - match command_buffer.status() { - metal::MTLCommandBufferStatus::Committed - | metal::MTLCommandBufferStatus::Scheduled - | metal::MTLCommandBufferStatus::Completed => { - panic!("Already committed"); - } - _ => {} - } - command_buffer.commit(); - command_buffer.wait_until_completed(); - *command_buffer = self.command_queue.new_command_buffer().to_owned(); - - Ok(()) - } - - pub fn kernels(&self) -> &Kernels { - &self.kernels - } - - pub fn device(&self) -> &metal::Device { - &self.device - } - - /// Creates a new buffer (not necessarily zeroed). - /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode) - /// This means the buffer data cannot be read on the CPU directly. - /// - /// [`name`] is only used to keep track of the resource origin in case of bugs - pub fn new_buffer( - &self, - element_count: usize, - dtype: DType, - name: &str, - ) -> Result> { - let size = (element_count * dtype.size_in_bytes()) as NSUInteger; - self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name) - } - - /// Creates a new buffer (not necessarily zeroed). - /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) - /// This means the buffer can be read on the CPU but will require manual - /// synchronization when the CPU memory is modified - /// Used as a bridge to gather data back from the GPU - pub fn new_buffer_managed(&self, size: NSUInteger) -> Result> { - self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed") - } - - /// Creates a new buffer from data. - /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode) - /// - /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes) - /// allocates the buffer and copies over the existing data before returning the MTLBuffer. 
- pub fn new_buffer_with_data(&self, data: &[T]) -> Result> { - let size = core::mem::size_of_val(data) as NSUInteger; - let new_buffer = self.device.new_buffer_with_data( - data.as_ptr() as *const c_void, - size, - MTLResourceOptions::StorageModeManaged, - ); - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; - let subbuffers = buffers - .entry((size, MTLResourceOptions::StorageModeManaged)) - .or_insert(vec![]); - - let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); - Ok(new_buffer) - } - - pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result> { - let buffer = self.allocate_buffer( - size_in_bytes as NSUInteger, - MTLResourceOptions::StorageModePrivate, - "allocate_zeros", - )?; - let command_buffer = self.command_buffer()?; - command_buffer.set_label("zeros"); - let blit = command_buffer.new_blit_command_encoder(); - blit.fill_buffer( - &buffer, - metal::NSRange { - location: 0, - length: buffer.length(), - }, - 0, - ); - blit.end_encoding(); - Ok(buffer) - } - - fn find_available_buffer( - &self, - size: NSUInteger, - option: MTLResourceOptions, - buffers: &RwLockWriteGuard, - ) -> Option> { - let mut best_buffer: Option<&Arc> = None; - let mut best_buffer_size: NSUInteger = NSUInteger::MAX; - for ((buffer_size, buffer_option), subbuffers) in buffers.iter() { - if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option { - for sub in subbuffers { - if Arc::strong_count(sub) == 1 { - best_buffer = Some(sub); - best_buffer_size = *buffer_size; - } - } - } - } - best_buffer.cloned() - } - - fn drop_unused_buffers(&self) -> Result<()> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; - for subbuffers in buffers.values_mut() { - let newbuffers = subbuffers - .iter() - .filter(|s| Arc::strong_count(*s) > 1) - .map(Arc::clone) - .collect(); - *subbuffers = newbuffers; - } - Ok(()) - } - - /// The critical allocator algorithm - fn allocate_buffer( - &self, - size: NSUInteger, - option: MTLResourceOptions, - _name: &str, - ) -> Result> { - let mut buffers = self.buffers.try_write().map_err(MetalError::from)?; - if let Some(b) = self.find_available_buffer(size, option, &buffers) { - // Cloning also ensures we increment the strong count - return Ok(b.clone()); - } - - let size = buf_size(size); - let subbuffers = buffers.entry((size, option)).or_insert(vec![]); - - let new_buffer = self.device.new_buffer(size as NSUInteger, option); - let new_buffer = Arc::new(new_buffer); - subbuffers.push(new_buffer.clone()); - - Ok(new_buffer) - } - - /// Create a metal GPU capture trace on [`path`]. - pub fn capture>(&self, path: P) -> Result<()> { - let capture = metal::CaptureManager::shared(); - let descriptor = metal::CaptureDescriptor::new(); - descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); - descriptor.set_capture_device(self); - descriptor.set_output_url(path); - - capture - .start_capture(&descriptor) - .map_err(MetalError::from)?; - Ok(()) - } -} - #[derive(Debug, Clone)] pub struct MetalStorage { /// The actual buffer containing the data. @@ -2055,10 +1776,6 @@ impl BackendDevice for MetalDevice { } } -fn buf_size(size: NSUInteger) -> NSUInteger { - (size - 1).next_power_of_two() as NSUInteger -} - fn read_to_vec(buffer: &Buffer, n: usize) -> Vec { let ptr = buffer.contents() as *const T; assert!(!ptr.is_null());
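The Map1/Map2/Map1Any/Map2Any traits added in candle-core/src/cuda_backend/utils.rs centralize the per-dtype dispatch: an op implements one generic f, and map writes the match over the storage variants exactly once. Below is a minimal, self-contained sketch of that dispatch pattern; Storage, Narrow and the trimmed-down Map1 are toy stand-ins used only for illustration, not candle's actual types.

    // Minimal sketch of the Map1 dispatch pattern used by the backend utils
    // (illustration only; Storage and Narrow are toy stand-ins).

    #[derive(Debug, Clone)]
    enum Storage {
        U32(Vec<u32>),
        F32(Vec<f32>),
        F64(Vec<f64>),
    }

    trait Map1 {
        // Each op provides a single generic implementation...
        fn f<T: Copy>(&self, vs: &[T]) -> Vec<T>;

        // ...and the per-dtype match is written exactly once, here.
        fn map(&self, s: &Storage) -> Storage {
            match s {
                Storage::U32(vs) => Storage::U32(self.f(vs)),
                Storage::F32(vs) => Storage::F32(self.f(vs)),
                Storage::F64(vs) => Storage::F64(self.f(vs)),
            }
        }
    }

    // A toy op that keeps the first `len` elements of the storage.
    struct Narrow {
        len: usize,
    }

    impl Map1 for Narrow {
        fn f<T: Copy>(&self, vs: &[T]) -> Vec<T> {
            vs.iter().copied().take(self.len).collect()
        }
    }

    fn main() {
        let s = Storage::F32(vec![1.0, 2.0, 3.0, 4.0]);
        let out = Narrow { len: 2 }.map(&s);
        println!("{out:?}"); // prints: F32([1.0, 2.0])
    }

The real traits additionally thread the device and layouts through f and report mismatched dtypes as a CudaError, but the shape of the dispatch is the same.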
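The MetalDevice allocator moved into candle-core/src/metal_backend/device.rs buckets buffers by rounded-up size and resource options, and hands a buffer out again once its Arc strong count is back to 1, i.e. only the allocator still references it. A self-contained sketch of that reuse policy follows, with a plain Vec<u8> standing in for an MTLBuffer; Allocator and Block are toy names, not the metal API.

    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    // A byte vector stands in for an MTLBuffer; the real key also carries
    // MTLResourceOptions so private and managed buffers never share a bucket.
    type Block = Vec<u8>;
    type BufferMap = HashMap<usize, Vec<Arc<Block>>>;

    struct Allocator {
        buffers: RwLock<BufferMap>,
    }

    impl Allocator {
        fn new() -> Self {
            Self { buffers: RwLock::new(HashMap::new()) }
        }

        // Round the requested size up to a power of two so that similarly sized
        // requests land in the same bucket (same idea as buf_size in the patch).
        fn bucket(size: usize) -> usize {
            (size.max(1) - 1).next_power_of_two()
        }

        fn allocate(&self, size: usize) -> Arc<Block> {
            let size = Self::bucket(size);
            let mut buffers = self.buffers.write().unwrap();
            let bucket = buffers.entry(size).or_default();
            // strong_count == 1 means only the allocator still holds this buffer,
            // so every user has dropped it and it is free to be handed out again.
            if let Some(b) = bucket.iter().find(|b| Arc::strong_count(b) == 1) {
                return b.clone();
            }
            let b = Arc::new(vec![0u8; size]);
            bucket.push(b.clone());
            b
        }

        // Full sweep dropping buffers nobody references anymore, mirroring
        // drop_unused_buffers.
        fn drop_unused(&self) {
            let mut buffers = self.buffers.write().unwrap();
            for bucket in buffers.values_mut() {
                bucket.retain(|b| Arc::strong_count(b) > 1);
            }
        }
    }

    fn main() {
        let alloc = Allocator::new();
        let a = alloc.allocate(1000); // lands in the 1024-byte bucket
        drop(a);                      // strong count falls back to 1
        let b = alloc.allocate(1024); // reuses the same 1024-byte block
        assert_eq!(b.len(), 1024);
        alloc.drop_unused();
    }

In the patch the map sits behind try_write and lock failures surface as MetalError::LockError rather than a panic.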
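MetalDevice::command_buffer batches work by counting compute encoders and, once more than compute_per_buffer have been handed out, committing the shared command buffer and rotating to a fresh one so CPU-side encoding can keep overlapping GPU execution. The sketch below mirrors only that rotation logic; TicketQueue, TicketBuffer and Device are toy types, nothing here is the metal crate's API.

    use std::sync::RwLock;

    // Toy stand-ins for a command queue and command buffer (illustration only).
    #[derive(Clone)]
    struct TicketBuffer {
        id: usize,
    }

    impl TicketBuffer {
        fn commit(&self) {
            println!("commit command buffer {}", self.id);
        }
    }

    struct TicketQueue {
        next_id: RwLock<usize>,
    }

    impl TicketQueue {
        fn new_command_buffer(&self) -> TicketBuffer {
            let mut next_id = self.next_id.write().unwrap();
            *next_id += 1;
            TicketBuffer { id: *next_id }
        }
    }

    struct Device {
        queue: TicketQueue,
        command_buffer: RwLock<TicketBuffer>,
        command_buffer_index: RwLock<usize>,
        compute_per_buffer: usize,
    }

    impl Device {
        // Mirrors the rotation in MetalDevice::command_buffer: reuse the current
        // buffer until the encoder budget is spent, then commit and start anew.
        fn command_buffer(&self) -> TicketBuffer {
            let mut current = self.command_buffer.write().unwrap();
            let mut index = self.command_buffer_index.write().unwrap();
            if *index > self.compute_per_buffer {
                current.commit();
                *current = self.queue.new_command_buffer();
                *index = 0;
            }
            *index += 1;
            current.clone()
        }
    }

    fn main() {
        let device = Device {
            queue: TicketQueue { next_id: RwLock::new(0) },
            command_buffer: RwLock::new(TicketBuffer { id: 0 }),
            command_buffer_index: RwLock::new(0),
            compute_per_buffer: 3,
        };
        // Once the counter exceeds compute_per_buffer, the next request
        // commits the current buffer and rotates to a new one.
        for _ in 0..10 {
            let _cb = device.command_buffer();
        }
    }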