Mirror of https://github.com/huggingface/candle.git

Backend refactoring. (#1966)

* Backend refactoring.
* Metal tweaks.
* Move the cudnn module.
candle-core/src/cpu_backend/mod.rs

@@ -4,6 +4,11 @@ use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 use rayon::prelude::*;
 
+mod utils;
+pub use utils::{
+    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+};
+
 const USE_IM2COL_CONV1D: bool = true;
 const USE_IM2COL_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
@@ -24,102 +29,6 @@ pub enum CpuStorage {
 #[derive(Debug, Clone)]
 pub struct CpuDevice;
 
-pub trait Map1 {
-    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
-
-    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
-        match vs {
-            CpuStorage::U8(vs) => Ok(CpuStorage::U8(self.f(vs, layout)?)),
-            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
-            CpuStorage::I64(vs) => Ok(CpuStorage::I64(self.f(vs, layout)?)),
-            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
-            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
-            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
-            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
-        }
-    }
-}
-
-pub trait Map1Any {
-    fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
-        &self,
-        vs: &[T],
-        layout: &Layout,
-        wrap: W,
-    ) -> Result<CpuStorage>;
-
-    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
-        match vs {
-            CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
-            CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
-            CpuStorage::I64(vs) => Ok(self.f(vs, layout, CpuStorage::I64)?),
-            CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
-            CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
-            CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
-            CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
-        }
-    }
-}
-
-type C = CpuStorage;
-pub trait Map2 {
-    const OP: &'static str;
-    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
-
-    fn map(
-        &self,
-        v1: &CpuStorage,
-        l1: &Layout,
-        v2: &CpuStorage,
-        l2: &Layout,
-    ) -> Result<CpuStorage> {
-        match (v1, v2) {
-            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
-            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
-            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
-            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
-            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
-            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: v1.dtype(),
-                rhs: v2.dtype(),
-                op: Self::OP,
-            }
-            .bt()),
-        }
-    }
-}
-
-pub trait Map2U8 {
-    const OP: &'static str;
-    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
-
-    fn map(
-        &self,
-        v1: &CpuStorage,
-        l1: &Layout,
-        v2: &CpuStorage,
-        l2: &Layout,
-    ) -> Result<CpuStorage> {
-        match (v1, v2) {
-            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: v1.dtype(),
-                rhs: v2.dtype(),
-                op: Self::OP,
-            }
-            .bt()),
-        }
-    }
-}
-
 struct Cmp(CmpOp);
 impl Map2U8 for Cmp {
     const OP: &'static str = "cmp";
@@ -366,275 +275,6 @@ impl<'a> Map1 for ReduceSum<'a> {
     }
 }
 
-pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
-    vs: &[T],
-    layout: &Layout,
-    mut f: F,
-) -> Vec<U> {
-    match layout.strided_blocks() {
-        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
-            [start_offset..start_offset + len]
-            .iter()
-            .map(|&v| f(v))
-            .collect(),
-        crate::StridedBlocks::MultipleBlocks {
-            block_start_index,
-            block_len,
-        } => {
-            let mut result = Vec::with_capacity(layout.shape().elem_count());
-            // Specialize the case where block_len is one to avoid the second loop.
-            if block_len == 1 {
-                for index in block_start_index {
-                    let v = unsafe { vs.get_unchecked(index) };
-                    result.push(f(*v))
-                }
-            } else {
-                for index in block_start_index {
-                    for offset in 0..block_len {
-                        let v = unsafe { vs.get_unchecked(index + offset) };
-                        result.push(f(*v))
-                    }
-                }
-            }
-            result
-        }
-    }
-}
-
-pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
-    vs: &[T],
-    layout: &Layout,
-    mut f: F,
-    mut f_vec: FV,
-) -> Vec<U> {
-    match layout.strided_blocks() {
-        crate::StridedBlocks::SingleBlock { start_offset, len } => {
-            let mut ys: Vec<U> = Vec::with_capacity(len);
-            let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
-            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
-            // SAFETY: values are all set by f_vec.
-            unsafe { ys.set_len(len) };
-            ys
-        }
-        crate::StridedBlocks::MultipleBlocks {
-            block_start_index,
-            block_len,
-        } => {
-            let el_count = layout.shape().elem_count();
-            // Specialize the case where block_len is one to avoid the second loop.
-            if block_len == 1 {
-                let mut result = Vec::with_capacity(el_count);
-                for index in block_start_index {
-                    let v = unsafe { vs.get_unchecked(index) };
-                    result.push(f(*v))
-                }
-                result
-            } else {
-                let mut ys: Vec<U> = Vec::with_capacity(el_count);
-                let ys_to_set = ys.spare_capacity_mut();
-                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
-                let mut dst_index = 0;
-                for src_index in block_start_index {
-                    let vs = &vs[src_index..src_index + block_len];
-                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
-                    f_vec(vs, ys);
-                    dst_index += block_len;
-                }
-                // SAFETY: values are all set by f_vec.
-                unsafe { ys.set_len(el_count) };
-                ys
-            }
-        }
-    }
-}
-
-// This function maps over two strided index sequences.
-pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-    lhs: &[T],
-    rhs: &[T],
-    mut f: F,
-) -> Vec<U> {
-    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
-        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
-            .iter()
-            .zip(rhs[o_r1..o_r2].iter())
-            .map(|(&l, &r)| f(l, r))
-            .collect(),
-        (Some((o_l1, o_l2)), None) => {
-            // TODO: Maybe we want to avoid going through the layout twice.
-            match rhs_l.offsets_b() {
-                Some(ob) => {
-                    let mut i_in_block = 0;
-                    let mut i_right_broadcast = 0;
-                    lhs[o_l1..o_l2]
-                        .iter()
-                        .map(|&l| {
-                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
-                            i_right_broadcast += 1;
-                            if i_right_broadcast >= ob.right_broadcast {
-                                i_in_block += 1;
-                                i_right_broadcast = 0;
-                            }
-                            if i_in_block >= ob.len {
-                                i_in_block = 0
-                            }
-                            f(l, *r)
-                        })
-                        .collect()
-                }
-                None => lhs_l
-                    .strided_index()
-                    .zip(rhs_l.strided_index())
-                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                    .collect(),
-            }
-        }
-        (None, Some((o_r1, o_r2))) => {
-            // TODO: Maybe we want to avoid going through the layout twice.
-            match lhs_l.offsets_b() {
-                Some(ob) => {
-                    let mut i_in_block = 0;
-                    let mut i_right_broadcast = 0;
-                    rhs[o_r1..o_r2]
-                        .iter()
-                        .map(|&r| {
-                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
-                            i_right_broadcast += 1;
-                            if i_right_broadcast >= ob.right_broadcast {
-                                i_in_block += 1;
-                                i_right_broadcast = 0;
-                            }
-                            if i_in_block >= ob.len {
-                                i_in_block = 0
-                            }
-                            f(*l, r)
-                        })
-                        .collect()
-                }
-                None => lhs_l
-                    .strided_index()
-                    .zip(rhs_l.strided_index())
-                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                    .collect(),
-            }
-        }
-        _ => lhs_l
-            .strided_index()
-            .zip(rhs_l.strided_index())
-            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-            .collect(),
-    }
-}
-
-// Similar to binary_map but with vectorized variants.
-pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-    lhs: &[T],
-    rhs: &[T],
-    mut f: F,
-    mut f_vec: FV,
-) -> Vec<T> {
-    let el_count = lhs_l.shape().elem_count();
-    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
-        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
-            let mut ys: Vec<T> = Vec::with_capacity(el_count);
-            let ys_to_set = ys.spare_capacity_mut();
-            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
-            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
-            // SAFETY: values are all set by f_vec.
-            unsafe { ys.set_len(el_count) };
-            ys
-        }
-        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
-            Some(ob) if ob.right_broadcast == 1 => {
-                let rhs = &rhs[ob.start..ob.start + ob.len];
-                let mut ys: Vec<T> = Vec::with_capacity(el_count);
-                let ys_to_set = ys.spare_capacity_mut();
-                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
-                let mut dst_i = 0;
-                for src_i in (o_l1..o_l2).step_by(ob.len) {
-                    f_vec(
-                        &lhs[src_i..src_i + ob.len],
-                        rhs,
-                        &mut ys_to_set[dst_i..dst_i + ob.len],
-                    );
-                    dst_i += ob.len;
-                }
-                // SAFETY: values are all set by f_vec.
-                unsafe { ys.set_len(el_count) };
-                ys
-            }
-            Some(ob) => {
-                let rhs = &rhs[ob.start..ob.start + ob.len];
-                let mut ys = lhs[o_l1..o_l2].to_vec();
-                for idx_l in 0..ob.left_broadcast {
-                    let start = idx_l * ob.len * ob.right_broadcast;
-                    for (i, &r) in rhs.iter().enumerate() {
-                        let start = start + i * ob.right_broadcast;
-                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
-                            *v = f(*v, r)
-                        }
-                    }
-                }
-                ys
-            }
-            None => lhs_l
-                .strided_index()
-                .zip(rhs_l.strided_index())
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect(),
-        },
-        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
-            Some(ob) if ob.right_broadcast == 1 => {
-                let lhs = &lhs[ob.start..ob.start + ob.len];
-                let mut ys: Vec<T> = Vec::with_capacity(el_count);
-                let ys_to_set = ys.spare_capacity_mut();
-                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
-                let mut dst_i = 0;
-                for src_i in (o_r1..o_r2).step_by(ob.len) {
-                    f_vec(
-                        lhs,
-                        &rhs[src_i..src_i + ob.len],
-                        &mut ys_to_set[dst_i..dst_i + ob.len],
-                    );
-                    dst_i += ob.len;
-                }
-                // SAFETY: values are all set by f_vec.
-                unsafe { ys.set_len(el_count) };
-                ys
-            }
-            Some(ob) => {
-                let lhs = &lhs[ob.start..ob.start + ob.len];
-                let mut ys = rhs[o_r1..o_r2].to_vec();
-                for idx_l in 0..ob.left_broadcast {
-                    let start = idx_l * ob.len * ob.right_broadcast;
-                    for (i, &l) in lhs.iter().enumerate() {
-                        let start = start + i * ob.right_broadcast;
-                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
-                            *v = f(l, *v)
-                        }
-                    }
-                }
-                ys
-            }
-            None => lhs_l
-                .strided_index()
-                .zip(rhs_l.strided_index())
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect(),
-        },
-        _ => lhs_l
-            .strided_index()
-            .zip(rhs_l.strided_index())
-            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-            .collect(),
-    }
-}
 
 struct Affine(f64, f64);
 
 impl Map1 for Affine {
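
Side note, not part of the diff: the i_in_block / i_right_broadcast counters in binary_map implement the broadcast walk over the non-contiguous operand described by Layout::offsets_b(): each of the ob.len source elements repeats right_broadcast times, and the whole block then repeats until the output is filled. A standalone hedged sketch of the same index sequence, with made-up sizes and no candle types:

    // Illustration only: reproduces the counter logic from binary_map above.
    fn broadcast_indices(len: usize, left_broadcast: usize, right_broadcast: usize) -> Vec<usize> {
        let total = len * left_broadcast * right_broadcast;
        let mut out = Vec::with_capacity(total);
        let (mut i_in_block, mut i_right) = (0, 0);
        for _ in 0..total {
            out.push(i_in_block); // index into the broadcast operand's block
            i_right += 1;
            if i_right >= right_broadcast {
                i_in_block += 1;
                i_right = 0;
            }
            if i_in_block >= len {
                i_in_block = 0;
            }
        }
        out
    }

    // broadcast_indices(3, 2, 2) == [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2]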
candle-core/src/cpu_backend/utils.rs (new file, 350 lines)

@@ -0,0 +1,350 @@
+/// Helper functions to write CPU kernels.
+use crate::backend::BackendStorage;
+use crate::{Error, Layout, Result, WithDType};
+
+type C = super::CpuStorage;
+pub trait Map1 {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>>;
+
+    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
+        match vs {
+            C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
+            C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
+            C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
+            C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
+            C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
+            C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
+            C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
+        }
+    }
+}
+
+pub trait Map1Any {
+    fn f<T: WithDType, W: Fn(Vec<T>) -> C>(&self, vs: &[T], layout: &Layout, wrap: W) -> Result<C>;
+
+    fn map(&self, vs: &C, layout: &Layout) -> Result<C> {
+        match vs {
+            C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
+            C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
+            C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
+            C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
+            C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
+            C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
+            C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
+        }
+    }
+}
+
+pub trait Map2 {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<T>>;
+
+    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+            (C::I64(v1), C::I64(v2)) => Ok(C::I64(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt()),
+        }
+    }
+}
+
+pub trait Map2U8 {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
+
+    fn map(&self, v1: &C, l1: &Layout, v2: &C, l2: &Layout) -> Result<C> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::U32(v1), C::U32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::I64(v1), C::I64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt()),
+        }
+    }
+}
+
+pub fn binary_map<T: Copy, U: Copy, F: FnMut(T, T) -> U>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+) -> Vec<U> {
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
+            .iter()
+            .zip(rhs[o_r1..o_r2].iter())
+            .map(|(&l, &r)| f(l, r))
+            .collect(),
+        (Some((o_l1, o_l2)), None) => {
+            // TODO: Maybe we want to avoid going through the layout twice.
+            match rhs_l.offsets_b() {
+                Some(ob) => {
+                    let mut i_in_block = 0;
+                    let mut i_right_broadcast = 0;
+                    lhs[o_l1..o_l2]
+                        .iter()
+                        .map(|&l| {
+                            let r = unsafe { rhs.get_unchecked(i_in_block + ob.start) };
+                            i_right_broadcast += 1;
+                            if i_right_broadcast >= ob.right_broadcast {
+                                i_in_block += 1;
+                                i_right_broadcast = 0;
+                            }
+                            if i_in_block >= ob.len {
+                                i_in_block = 0
+                            }
+                            f(l, *r)
+                        })
+                        .collect()
+                }
+                None => lhs_l
+                    .strided_index()
+                    .zip(rhs_l.strided_index())
+                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                    .collect(),
+            }
+        }
+        (None, Some((o_r1, o_r2))) => {
+            // TODO: Maybe we want to avoid going through the layout twice.
+            match lhs_l.offsets_b() {
+                Some(ob) => {
+                    let mut i_in_block = 0;
+                    let mut i_right_broadcast = 0;
+                    rhs[o_r1..o_r2]
+                        .iter()
+                        .map(|&r| {
+                            let l = unsafe { lhs.get_unchecked(i_in_block + ob.start) };
+                            i_right_broadcast += 1;
+                            if i_right_broadcast >= ob.right_broadcast {
+                                i_in_block += 1;
+                                i_right_broadcast = 0;
+                            }
+                            if i_in_block >= ob.len {
+                                i_in_block = 0
+                            }
+                            f(*l, r)
+                        })
+                        .collect()
+                }
+                None => lhs_l
+                    .strided_index()
+                    .zip(rhs_l.strided_index())
+                    .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                    .collect(),
+            }
+        }
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect(),
+    }
+}
+
+// Similar to binary_map but with vectorized variants.
+pub fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>(
+    lhs_l: &Layout,
+    rhs_l: &Layout,
+    lhs: &[T],
+    rhs: &[T],
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<T> {
+    let el_count = lhs_l.shape().elem_count();
+    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
+        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => {
+            let mut ys: Vec<T> = Vec::with_capacity(el_count);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+            f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(el_count) };
+            ys
+        }
+        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let rhs = &rhs[ob.start..ob.start + ob.len];
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_l1..o_l2).step_by(ob.len) {
+                    f_vec(
+                        &lhs[src_i..src_i + ob.len],
+                        rhs,
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                let rhs = &rhs[ob.start..ob.start + ob.len];
+                let mut ys = lhs[o_l1..o_l2].to_vec();
+                for idx_l in 0..ob.left_broadcast {
+                    let start = idx_l * ob.len * ob.right_broadcast;
+                    for (i, &r) in rhs.iter().enumerate() {
+                        let start = start + i * ob.right_broadcast;
+                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
+                            *v = f(*v, r)
+                        }
+                    }
+                }
+                ys
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
+            Some(ob) if ob.right_broadcast == 1 => {
+                let lhs = &lhs[ob.start..ob.start + ob.len];
+                let mut ys: Vec<T> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };
+                let mut dst_i = 0;
+                for src_i in (o_r1..o_r2).step_by(ob.len) {
+                    f_vec(
+                        lhs,
+                        &rhs[src_i..src_i + ob.len],
+                        &mut ys_to_set[dst_i..dst_i + ob.len],
+                    );
+                    dst_i += ob.len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+            Some(ob) => {
+                let lhs = &lhs[ob.start..ob.start + ob.len];
+                let mut ys = rhs[o_r1..o_r2].to_vec();
+                for idx_l in 0..ob.left_broadcast {
+                    let start = idx_l * ob.len * ob.right_broadcast;
+                    for (i, &l) in lhs.iter().enumerate() {
+                        let start = start + i * ob.right_broadcast;
+                        for v in ys[start..start + ob.right_broadcast].iter_mut() {
+                            *v = f(l, *v)
+                        }
+                    }
+                }
+                ys
+            }
+            None => lhs_l
+                .strided_index()
+                .zip(rhs_l.strided_index())
+                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+                .collect(),
+        },
+        _ => lhs_l
+            .strided_index()
+            .zip(rhs_l.strided_index())
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect(),
+    }
+}
+
+pub fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
+            [start_offset..start_offset + len]
+            .iter()
+            .map(|&v| f(v))
+            .collect(),
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let mut result = Vec::with_capacity(layout.shape().elem_count());
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                for index in block_start_index {
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+            } else {
+                for index in block_start_index {
+                    for offset in 0..block_len {
+                        let v = unsafe { vs.get_unchecked(index + offset) };
+                        result.push(f(*v))
+                    }
+                }
+            }
+            result
+        }
+    }
+}
+
+pub fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
+    vs: &[T],
+    layout: &Layout,
+    mut f: F,
+    mut f_vec: FV,
+) -> Vec<U> {
+    match layout.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let mut ys: Vec<U> = Vec::with_capacity(len);
+            let ys_to_set = ys.spare_capacity_mut();
+            let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+            f_vec(&vs[start_offset..start_offset + len], ys_to_set);
+            // SAFETY: values are all set by f_vec.
+            unsafe { ys.set_len(len) };
+            ys
+        }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let el_count = layout.shape().elem_count();
+            // Specialize the case where block_len is one to avoid the second loop.
+            if block_len == 1 {
+                let mut result = Vec::with_capacity(el_count);
+                for index in block_start_index {
+                    let v = unsafe { vs.get_unchecked(index) };
+                    result.push(f(*v))
+                }
+                result
+            } else {
+                let mut ys: Vec<U> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+                let mut dst_index = 0;
+                for src_index in block_start_index {
+                    let vs = &vs[src_index..src_index + block_len];
+                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
+                    f_vec(vs, ys);
+                    dst_index += block_len;
+                }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
+            }
+        }
+    }
+}
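
Side note, not part of the diff: a minimal sketch of how a CPU kernel uses these helpers. `Sqr` is a hypothetical op name used only for illustration; candle's real ops follow the same pattern:

    // Hypothetical example: an elementwise square built on Map1 and unary_map.
    struct Sqr;

    impl Map1 for Sqr {
        fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
            // unary_map walks both contiguous and strided layouts; the op only
            // supplies the per-element closure.
            Ok(unary_map(vs, layout, |v| v * v))
        }
    }

    // The blanket `map` method then dispatches over every dtype:
    // let out = Sqr.map(&cpu_storage, &layout)?;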
candle-core/src/cuda_backend/device.rs (new file, 410 lines)

@@ -0,0 +1,410 @@
+use crate::backend::BackendDevice;
+use crate::{CpuStorage, DType, Layout, Result, Shape};
+pub use candle_kernels as kernels;
+pub use cudarc;
+use cudarc::driver::{CudaFunction, LaunchAsync, LaunchConfig};
+use half::{bf16, f16};
+use std::sync::{Arc, Mutex};
+
+use super::{CudaError, CudaStorage, CudaStorageSlice, WrapErr};
+
+/// Unique identifier for cuda devices.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub struct DeviceId(usize);
+
+impl DeviceId {
+    fn new() -> Self {
+        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
+        use std::sync::atomic;
+        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
+        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
+    }
+}
+
+struct CudaRng(cudarc::curand::CudaRng);
+unsafe impl Send for CudaRng {}
+
+#[derive(Clone)]
+pub struct CudaDevice {
+    id: DeviceId,
+    device: Arc<cudarc::driver::CudaDevice>,
+    pub(crate) blas: Arc<cudarc::cublas::CudaBlas>,
+    curand: Arc<Mutex<CudaRng>>,
+}
+
+impl std::fmt::Debug for CudaDevice {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "CudaDevice({:?})", self.id)
+    }
+}
+
+impl std::ops::Deref for CudaDevice {
+    type Target = Arc<cudarc::driver::CudaDevice>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.device
+    }
+}
+
+impl CudaDevice {
+    pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
+        self.device.clone()
+    }
+
+    pub fn id(&self) -> DeviceId {
+        self.id
+    }
+
+    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+        let slice = match dtype {
+            DType::U8 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
+                let params = (&data, v as u8, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::U8(data)
+            }
+            DType::U32 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
+                let params = (&data, v as u32, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::U32(data)
+            }
+            DType::I64 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
+                let params = (&data, v as i64, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::I64(data)
+            }
+            DType::BF16 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
+                let params = (&data, bf16::from_f64(v), elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            DType::F16 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
+                let params = (&data, f16::from_f64(v), elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::F16(data)
+            }
+            DType::F32 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
+                let params = (&data, v as f32, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                // SAFETY: Set later by running the fill kernel.
+                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
+                let params = (&data, v, elem_count);
+                unsafe { func.launch(cfg, params) }.w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
+        if !self.has_func(module_name, module_name) {
+            // Leaking the string here is a bit sad but we need a &'static str and this is only
+            // done once per kernel name.
+            let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
+            self.load_ptx(ptx.into(), module_name, &[static_module_name])
+                .map_err(|cuda| CudaError::Load {
+                    cuda,
+                    module_name: module_name.to_string(),
+                })
+                .w()?;
+        }
+        self.get_func(module_name, module_name)
+            // Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
+            // able to only build the error value if needed.
+            .ok_or(CudaError::MissingKernel {
+                module_name: module_name.to_string(),
+            })
+            .w()
+    }
+}
+
+impl BackendDevice for CudaDevice {
+    type Storage = CudaStorage;
+
+    fn new(ordinal: usize) -> Result<Self> {
+        let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
+        let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
+        let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
+        Ok(Self {
+            id: DeviceId::new(),
+            device,
+            blas: Arc::new(blas),
+            curand: Arc::new(Mutex::new(CudaRng(curand))),
+        })
+    }
+
+    fn set_seed(&self, seed: u64) -> Result<()> {
+        // We do not call set_seed but instead create a new curand object. This ensures that the
+        // state will be identical and the same random numbers will be generated.
+        let mut curand = self.curand.lock().unwrap();
+        curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
+        Ok(())
+    }
+
+    fn location(&self) -> crate::DeviceLocation {
+        crate::DeviceLocation::Cuda {
+            gpu_id: self.device.ordinal(),
+        }
+    }
+
+    fn same_device(&self, rhs: &Self) -> bool {
+        self.id == rhs.id
+    }
+
+    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let slice = match dtype {
+            DType::U8 => {
+                let data = self.alloc_zeros::<u8>(elem_count).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            DType::U32 => {
+                let data = self.alloc_zeros::<u32>(elem_count).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            DType::I64 => {
+                let data = self.alloc_zeros::<i64>(elem_count).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            DType::BF16 => {
+                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            DType::F16 => {
+                let data = self.alloc_zeros::<f16>(elem_count).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            DType::F32 => {
+                let data = self.alloc_zeros::<f32>(elem_count).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let data = self.alloc_zeros::<f64>(elem_count).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
+        let elem_count = shape.elem_count();
+        let curand = self.curand.lock().unwrap();
+        let slice = match dtype {
+            // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+            // cudarc changes.
+            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+                Err(CudaError::UnsupportedDtype {
+                    dtype,
+                    op: "rand_uniform",
+                })
+                .w()?
+            }
+            DType::F32 => {
+                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                curand.0.fill_with_uniform(&mut data).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                curand.0.fill_with_uniform(&mut data).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        let slice = if lo == 0. && up == 1.0 {
+            slice
+        } else {
+            use super::utils::Map1;
+            let layout = Layout::contiguous(shape);
+            super::Affine(up - lo, lo).map(&slice, self, &layout)?
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
+        // TODO: Add support for F16 and BF16 though this is likely to require some upstream
+        // cudarc changes.
+        let elem_count = shape.elem_count();
+        let curand = self.curand.lock().unwrap();
+        // curand can only generate an even number of values.
+        // https://github.com/huggingface/candle/issues/734
+        let elem_count_round = if elem_count % 2 == 1 {
+            elem_count + 1
+        } else {
+            elem_count
+        };
+        let slice = match dtype {
+            DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
+                Err(CudaError::UnsupportedDtype {
+                    dtype,
+                    op: "rand_normal",
+                })
+                .w()?
+            }
+            DType::F32 => {
+                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
+                curand
+                    .0
+                    .fill_with_normal(&mut data, mean as f32, std as f32)
+                    .w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
+                curand.0.fill_with_normal(&mut data, mean, std).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
+        self.const_impl(1., shape, dtype)
+    }
+
+    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
+        let elem_count = shape.elem_count();
+        let slice = match dtype {
+            DType::U8 => {
+                let data = self.alloc::<u8>(elem_count).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            DType::U32 => {
+                let data = self.alloc::<u32>(elem_count).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            DType::I64 => {
+                let data = self.alloc::<i64>(elem_count).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            DType::BF16 => {
+                let data = self.alloc::<bf16>(elem_count).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            DType::F16 => {
+                let data = self.alloc::<f16>(elem_count).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            DType::F32 => {
+                let data = self.alloc::<f32>(elem_count).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            DType::F64 => {
+                let data = self.alloc::<f64>(elem_count).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
+        let slice = match storage {
+            CpuStorage::U8(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            CpuStorage::U32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            CpuStorage::I64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            CpuStorage::BF16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            CpuStorage::F16(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            CpuStorage::F32(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            CpuStorage::F64(storage) => {
+                let data = self.htod_sync_copy(storage).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+
+    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
+        let slice = match storage {
+            CpuStorage::U8(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::U8(data)
+            }
+            CpuStorage::U32(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::U32(data)
+            }
+            CpuStorage::I64(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::I64(data)
+            }
+            CpuStorage::BF16(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::BF16(data)
+            }
+            CpuStorage::F16(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::F16(data)
+            }
+            CpuStorage::F32(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::F32(data)
+            }
+            CpuStorage::F64(storage) => {
+                let data = self.htod_copy(storage).w()?;
+                CudaStorageSlice::F64(data)
+            }
+        };
+        Ok(CudaStorage {
+            slice,
+            device: self.clone(),
+        })
+    }
+}
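
Side note, not part of the diff: a hedged within-crate sketch of how this device API composes, relying only on items visible in the file above; `demo` is a hypothetical function:

    // Hypothetical usage sketch.
    fn demo() -> crate::Result<()> {
        use crate::backend::BackendDevice;
        use crate::{DType, Shape};

        // `new` builds the cudarc device plus the cuBLAS and cuRAND handles.
        let dev = CudaDevice::new(0)?;
        // `ones_impl` goes through `const_impl`, which launches a fill kernel.
        let _ones = dev.ones_impl(&Shape::from((2, 3)), DType::F32)?;
        // `set_seed` swaps in a fresh cuRAND generator rather than reseeding the
        // existing one, so the generator state is exactly reproducible.
        dev.set_seed(42)?;
        Ok(())
    }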
candle-core/src/cuda_backend/mod.rs

@@ -5,11 +5,17 @@ pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
 use cudarc::driver::{
-    CudaFunction, CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig,
-    ValidAsZeroBits,
+    CudaSlice, DevicePtr, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig, ValidAsZeroBits,
 };
 use half::{bf16, f16};
-use std::sync::{Arc, Mutex};
+
+mod device;
+pub use device::{CudaDevice, DeviceId};
+mod utils;
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
+
+#[cfg(feature = "cudnn")]
+pub mod cudnn;
 
 enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
@@ -87,44 +93,6 @@ impl From<CudaError> for crate::Error {
     }
 }
 
-/// Unique identifier for cuda devices.
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
-pub struct DeviceId(usize);
-
-impl DeviceId {
-    fn new() -> Self {
-        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
-        use std::sync::atomic;
-        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
-        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
-    }
-}
-
-struct CudaRng(cudarc::curand::CudaRng);
-unsafe impl Send for CudaRng {}
-
-#[derive(Clone)]
-pub struct CudaDevice {
-    id: DeviceId,
-    device: Arc<cudarc::driver::CudaDevice>,
-    blas: Arc<cudarc::cublas::CudaBlas>,
-    curand: Arc<Mutex<CudaRng>>,
-}
-
-impl std::fmt::Debug for CudaDevice {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "CudaDevice({:?})", self.id)
-    }
-}
-
-impl std::ops::Deref for CudaDevice {
-    type Target = Arc<cudarc::driver::CudaDevice>;
-
-    fn deref(&self) -> &Self::Target {
-        &self.device
-    }
-}
-
 pub trait WrapErr<O> {
     fn w(self) -> std::result::Result<O, crate::Error>;
 }
@ -135,368 +103,6 @@ impl<O, E: Into<CudaError>> WrapErr<O> for std::result::Result<O, E> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CudaDevice {
|
|
||||||
pub fn cuda_device(&self) -> Arc<cudarc::driver::CudaDevice> {
|
|
||||||
self.device.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
|
||||||
self.id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_u8", kernels::FILL)?;
|
|
||||||
let params = (&data, v as u8, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_u32", kernels::FILL)?;
|
|
||||||
let params = (&data, v as u32, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_i64", kernels::FILL)?;
|
|
||||||
let params = (&data, v as i64, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_bf16", kernels::FILL)?;
|
|
||||||
let params = (&data, bf16::from_f64(v), elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f16", kernels::FILL)?;
|
|
||||||
let params = (&data, f16::from_f64(v), elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f32", kernels::FILL)?;
|
|
||||||
let params = (&data, v as f32, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
||||||
let func = self.get_or_load_func("fill_f64", kernels::FILL)?;
|
|
||||||
let params = (&data, v, elem_count);
|
|
||||||
unsafe { func.launch(cfg, params) }.w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_or_load_func(&self, module_name: &str, ptx: &'static str) -> Result<CudaFunction> {
|
|
||||||
if !self.has_func(module_name, module_name) {
|
|
||||||
// Leaking the string here is a bit sad but we need a &'static str and this is only
|
|
||||||
// done once per kernel name.
|
|
||||||
let static_module_name = Box::leak(module_name.to_string().into_boxed_str());
|
|
||||||
self.load_ptx(ptx.into(), module_name, &[static_module_name])
|
|
||||||
.map_err(|cuda| CudaError::Load {
|
|
||||||
cuda,
|
|
||||||
module_name: module_name.to_string(),
|
|
||||||
})
|
|
||||||
.w()?;
|
|
||||||
}
|
|
||||||
self.get_func(module_name, module_name)
|
|
||||||
// Clippy recommends this `ok_or` rather than `ok_or_else` so hopefully the compiler is
|
|
||||||
// able to only build the error value if needed.
|
|
||||||
.ok_or(CudaError::MissingKernel {
|
|
||||||
module_name: module_name.to_string(),
|
|
||||||
})
|
|
||||||
.w()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BackendDevice for CudaDevice {
|
|
||||||
type Storage = CudaStorage;
|
|
||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
|
||||||
let device = cudarc::driver::CudaDevice::new(ordinal).w()?;
|
|
||||||
let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?;
|
|
||||||
let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?;
|
|
||||||
Ok(Self {
|
|
||||||
id: DeviceId::new(),
|
|
||||||
device,
|
|
||||||
blas: Arc::new(blas),
|
|
||||||
curand: Arc::new(Mutex::new(CudaRng(curand))),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
|
||||||
// We do not call set_seed but instead create a new curand object. This ensures that the
|
|
||||||
// state will be identical and the same random numbers will be generated.
|
|
||||||
let mut curand = self.curand.lock().unwrap();
|
|
||||||
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
|
||||||
crate::DeviceLocation::Cuda {
|
|
||||||
gpu_id: self.device.ordinal(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn same_device(&self, rhs: &Self) -> bool {
|
|
||||||
self.id == rhs.id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
let data = self.alloc_zeros::<u8>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
let data = self.alloc_zeros::<u32>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
let data = self.alloc_zeros::<i64>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
let data = self.alloc_zeros::<bf16>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
let data = self.alloc_zeros::<f16>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let data = self.alloc_zeros::<f32>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let data = self.alloc_zeros::<f64>(elem_count).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, lo: f64, up: f64) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let curand = self.curand.lock().unwrap();
|
|
||||||
let slice = match dtype {
|
|
||||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
|
||||||
// cudarc changes.
|
|
||||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
|
||||||
Err(CudaError::UnsupportedDtype {
|
|
||||||
dtype,
|
|
||||||
op: "rand_uniform",
|
|
||||||
})
|
|
||||||
.w()?
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
|
|
||||||
curand.0.fill_with_uniform(&mut data).w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
|
|
||||||
curand.0.fill_with_uniform(&mut data).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let slice = if lo == 0. && up == 1.0 {
|
|
||||||
slice
|
|
||||||
} else {
|
|
||||||
let layout = Layout::contiguous(shape);
|
|
||||||
Affine(up - lo, lo).map(&slice, self, &layout)?
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, std: f64) -> Result<CudaStorage> {
|
|
||||||
// TODO: Add support for F16 and BF16 though this is likely to require some upstream
|
|
||||||
// cudarc changes.
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let curand = self.curand.lock().unwrap();
|
|
||||||
// curand can only generate an odd number of values.
|
|
||||||
// https://github.com/huggingface/candle/issues/734
|
|
||||||
let elem_count_round = if elem_count % 2 == 1 {
|
|
||||||
elem_count + 1
|
|
||||||
} else {
|
|
||||||
elem_count
|
|
||||||
};
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => {
|
|
||||||
Err(CudaError::UnsupportedDtype {
|
|
||||||
dtype,
|
|
||||||
op: "rand_normal",
|
|
||||||
})
|
|
||||||
.w()?
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
|
|
||||||
curand
|
|
||||||
.0
|
|
||||||
.fill_with_normal(&mut data, mean as f32, std as f32)
|
|
||||||
.w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
|
|
||||||
curand.0.fill_with_normal(&mut data, mean, std).w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
        self.const_impl(1., shape, dtype)
    }

    unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
        let elem_count = shape.elem_count();
        let slice = match dtype {
            DType::U8 => {
                let data = self.alloc::<u8>(elem_count).w()?;
                CudaStorageSlice::U8(data)
            }
            DType::U32 => {
                let data = self.alloc::<u32>(elem_count).w()?;
                CudaStorageSlice::U32(data)
            }
            DType::I64 => {
                let data = self.alloc::<i64>(elem_count).w()?;
                CudaStorageSlice::I64(data)
            }
            DType::BF16 => {
                let data = self.alloc::<bf16>(elem_count).w()?;
                CudaStorageSlice::BF16(data)
            }
            DType::F16 => {
                let data = self.alloc::<f16>(elem_count).w()?;
                CudaStorageSlice::F16(data)
            }
            DType::F32 => {
                let data = self.alloc::<f32>(elem_count).w()?;
                CudaStorageSlice::F32(data)
            }
            DType::F64 => {
                let data = self.alloc::<f64>(elem_count).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

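Because `alloc_uninit` returns device memory whose contents are arbitrary, a caller is expected to overwrite every element before anything reads the storage back, which is presumably why the function is `unsafe`. A hypothetical caller pattern, assuming `dev: &CudaDevice` and a `shape` in scope:

let out = unsafe { dev.alloc_uninit(&shape, DType::F32)? };
// ...launch a kernel that writes all shape.elem_count() elements of `out`
// before the storage is copied back or used as an input...
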
    fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_sync_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }

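From user code this path is normally reached through the tensor API rather than called directly. A hedged end-to-end sketch, assuming a CUDA-enabled build of candle: moving a CPU tensor to the GPU should bottom out in `storage_from_cpu_storage` (or the owned variant below).

use candle_core::{Device, Tensor};

let cuda = Device::new_cuda(0)?;
let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
let t = t.to_device(&cuda)?; // host-to-device copy of the underlying storage
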
    fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<CudaStorage> {
        let slice = match storage {
            CpuStorage::U8(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U8(data)
            }
            CpuStorage::U32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::U32(data)
            }
            CpuStorage::I64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::I64(data)
            }
            CpuStorage::BF16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::BF16(data)
            }
            CpuStorage::F16(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F16(data)
            }
            CpuStorage::F32(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F32(data)
            }
            CpuStorage::F64(storage) => {
                let data = self.htod_copy(storage).w()?;
                CudaStorageSlice::F64(data)
            }
        };
        Ok(CudaStorage {
            slice,
            device: self.clone(),
        })
    }
}

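The only difference from the borrowed variant above is how the host data is handed to cudarc. A sketch of the two call shapes, assuming the usual cudarc semantics (`htod_sync_copy` copies out of a borrowed slice; `htod_copy` consumes the `Vec`, which lets cudarc keep the source alive for an asynchronous transfer instead of cloning it):

let v: Vec<f32> = vec![1., 2., 3.];
let borrowed = dev.htod_sync_copy(&v).w()?; // `v` is still usable afterwards
let owned = dev.htod_copy(v).w()?;          // consumes `v`, no extra host-side copy
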
#[derive(Debug)]
pub enum CudaStorageSlice {
    U8(CudaSlice<u8>),
@@ -507,133 +113,6 @@ pub enum CudaStorageSlice {
     F32(CudaSlice<f32>),
     F64(CudaSlice<f64>),
 }
-type S = CudaStorageSlice;
-
-pub trait Map1 {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        src: &CudaSlice<T>,
-        dev: &CudaDevice,
-        layout: &Layout,
-    ) -> Result<CudaSlice<T>>;
-
-    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
-        let out = match s {
-            S::U8(s) => S::U8(self.f(s, d, l)?),
-            S::U32(s) => S::U32(self.f(s, d, l)?),
-            S::I64(s) => S::I64(self.f(s, d, l)?),
-            S::BF16(s) => S::BF16(self.f(s, d, l)?),
-            S::F16(s) => S::F16(self.f(s, d, l)?),
-            S::F32(s) => S::F32(self.f(s, d, l)?),
-            S::F64(s) => S::F64(self.f(s, d, l)?),
-        };
-        Ok(out)
-    }
-}
-
-pub trait Map2 {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        src1: &CudaSlice<T>,
-        layout1: &Layout,
-        src2: &CudaSlice<T>,
-        layout2: &Layout,
-        dev: &CudaDevice,
-    ) -> Result<CudaSlice<T>>;
-
-    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
-        let out = match (s1, s2) {
-            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
-            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
-            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
-            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
-            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
-            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
-            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
-            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
-        };
-        Ok(out)
-    }
-}
-
-pub trait Map2InPlace {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
-        src: &CudaSlice<T>,
-        src_l: &Layout,
-        dev: &CudaDevice,
-    ) -> Result<()>;
-
-    fn map(
-        &self,
-        dst: &mut S,
-        dst_s: &Shape,
-        src: &S,
-        src_l: &Layout,
-        d: &CudaDevice,
-    ) -> Result<()> {
-        match (dst, src) {
-            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
-            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
-        }
-    }
-}
-
-pub trait Map1Any {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
-        &self,
-        src: &CudaSlice<T>,
-        dev: &CudaDevice,
-        layout: &Layout,
-        wrap: W,
-    ) -> Result<S>;
-
-    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
-        let out = match s {
-            S::U8(s) => self.f(s, d, l, S::U8)?,
-            S::U32(s) => self.f(s, d, l, S::U32)?,
-            S::I64(s) => self.f(s, d, l, S::I64)?,
-            S::BF16(s) => self.f(s, d, l, S::BF16)?,
-            S::F16(s) => self.f(s, d, l, S::F16)?,
-            S::F32(s) => self.f(s, d, l, S::F32)?,
-            S::F64(s) => self.f(s, d, l, S::F64)?,
-        };
-        Ok(out)
-    }
-}
-
-pub trait Map2Any {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        src1: &CudaSlice<T>,
-        layout1: &Layout,
-        src2: &CudaSlice<T>,
-        layout2: &Layout,
-        dev: &CudaDevice,
-    ) -> Result<S>;
-
-    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
-        let out = match (s1, s2) {
-            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
-            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
-            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
-            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
-            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
-            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
-            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
-            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
-        };
-        Ok(out)
-    }
-}
-
 struct Clone;
 impl Map1 for Clone {
candle-core/src/cuda_backend/utils.rs (new file, 134 lines)
@@ -0,0 +1,134 @@
/// Helper functions to plug cuda kernels in candle.
use crate::{Layout, Result, Shape, WithDType};
pub use cudarc;
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};

use super::{CudaDevice, CudaError, WrapErr};

pub type S = super::CudaStorageSlice;

pub trait Map1 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => S::U8(self.f(s, d, l)?),
            S::U32(s) => S::U32(self.f(s, d, l)?),
            S::I64(s) => S::I64(self.f(s, d, l)?),
            S::BF16(s) => S::BF16(self.f(s, d, l)?),
            S::F16(s) => S::F16(self.f(s, d, l)?),
            S::F32(s) => S::F32(self.f(s, d, l)?),
            S::F64(s) => S::F64(self.f(s, d, l)?),
        };
        Ok(out)
    }
}

pub trait Map2 {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<CudaSlice<T>>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => S::U8(self.f(s1, l1, s2, l2, d)?),
            (S::U32(s1), S::U32(s2)) => S::U32(self.f(s1, l1, s2, l2, d)?),
            (S::I64(s1), S::I64(s2)) => S::I64(self.f(s1, l1, s2, l2, d)?),
            (S::BF16(s1), S::BF16(s2)) => S::BF16(self.f(s1, l1, s2, l2, d)?),
            (S::F16(s1), S::F16(s2)) => S::F16(self.f(s1, l1, s2, l2, d)?),
            (S::F32(s1), S::F32(s2)) => S::F32(self.f(s1, l1, s2, l2, d)?),
            (S::F64(s1), S::F64(s2)) => S::F64(self.f(s1, l1, s2, l2, d)?),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        };
        Ok(out)
    }
}

pub trait Map2InPlace {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        dst: &mut CudaSlice<T>,
        dst_shape: &Shape,
        src: &CudaSlice<T>,
        src_l: &Layout,
        dev: &CudaDevice,
    ) -> Result<()>;

    fn map(
        &self,
        dst: &mut S,
        dst_s: &Shape,
        src: &S,
        src_l: &Layout,
        d: &CudaDevice,
    ) -> Result<()> {
        match (dst, src) {
            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
            _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
        }
    }
}

pub trait Map1Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> S>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
        wrap: W,
    ) -> Result<S>;

    fn map(&self, s: &S, d: &CudaDevice, l: &Layout) -> Result<S> {
        let out = match s {
            S::U8(s) => self.f(s, d, l, S::U8)?,
            S::U32(s) => self.f(s, d, l, S::U32)?,
            S::I64(s) => self.f(s, d, l, S::I64)?,
            S::BF16(s) => self.f(s, d, l, S::BF16)?,
            S::F16(s) => self.f(s, d, l, S::F16)?,
            S::F32(s) => self.f(s, d, l, S::F32)?,
            S::F64(s) => self.f(s, d, l, S::F64)?,
        };
        Ok(out)
    }
}

pub trait Map2Any {
    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
        &self,
        src1: &CudaSlice<T>,
        layout1: &Layout,
        src2: &CudaSlice<T>,
        layout2: &Layout,
        dev: &CudaDevice,
    ) -> Result<S>;

    fn map(&self, s1: &S, l1: &Layout, s2: &S, l2: &Layout, d: &CudaDevice) -> Result<S> {
        let out = match (s1, s2) {
            (S::U8(s1), S::U8(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::U32(s1), S::U32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::I64(s1), S::I64(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::BF16(s1), S::BF16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?,
            (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?,
            _ => Err(CudaError::InternalError("dtype mismatch in binary op")).w()?,
        };
        Ok(out)
    }
}
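The point of these traits is that the per-dtype `match` is written once per trait shape and an implementor only supplies the generic `f`. A hypothetical call site, assuming a `CudaStorageSlice` value `slice`, its `layout`, and the owning `dev` are in scope (`Clone` being the backend's unit struct shown above, not `std::clone::Clone`):

// Dispatches on the runtime dtype of `slice`, hands the concrete
// CudaSlice<T> to `Clone::f`, and re-wraps the result in the same variant.
let copied: S = Clone.map(&slice, &dev, &layout)?;
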
candle-core/src/lib.rs
@@ -43,8 +43,6 @@ pub mod cpu;
 pub mod cpu_backend;
 #[cfg(feature = "cuda")]
 pub mod cuda_backend;
-#[cfg(feature = "cudnn")]
-pub mod cudnn;
 mod custom_op;
 mod device;
 pub mod display;
@@ -73,6 +71,9 @@ pub mod test_utils;
 pub mod utils;
 mod variable;
+
+#[cfg(feature = "cudnn")]
+pub use cuda_backend::cudnn;
 pub use cpu_backend::CpuStorage;
 pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
 pub use device::{Device, DeviceLocation, NdArray};
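Net effect of the move from a downstream crate's point of view (hypothetical caller, built with the `cudnn` feature enabled): the crate-level path keeps compiling because of the re-export added above, even though the module now lives under `cuda_backend`.

#[cfg(feature = "cudnn")]
use candle_core::cudnn; // resolves via `pub use cuda_backend::cudnn;`
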
candle-core/src/metal_backend/device.rs (new file, 287 lines)
@@ -0,0 +1,287 @@
use crate::{DType, Result};
use candle_metal_kernels::Kernels;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard};

use super::MetalError;

/// Unique identifier for metal devices.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DeviceId(usize);

impl DeviceId {
    pub(crate) fn new() -> Self {
        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
        use std::sync::atomic;
        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
    }
}

type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
type AllocatedBuffers = Arc<RwLock<BufferMap>>;

#[derive(Clone)]
pub struct MetalDevice {
    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
    /// the device itself.
    pub(crate) id: DeviceId,

    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
    pub(crate) device: metal::Device,

    /// Single command queue for the entire device.
    pub(crate) command_queue: CommandQueue,
    /// One command buffer at a time.
    /// The scheduler works by allowing multiple
    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
    /// to start to work).
    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
    /// for their START time, but there's no guarantee that command buffer1 will finish before
    /// command buffer2 starts (or there are metal bugs there)
    pub(crate) command_buffer: Arc<RwLock<CommandBuffer>>,
    /// Keeps track of the current amount of compute command encoders on the current
    /// command buffer
    /// Arc, RwLock because of the interior mutability.
    pub(crate) command_buffer_index: Arc<RwLock<usize>>,
    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
    pub(crate) compute_per_buffer: usize,
    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
    /// Heavily used by [`candle_metal_kernels`]
    pub(crate) kernels: Arc<Kernels>,
    /// Simple allocator struct.
    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
    /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
    /// (could be linked to FFI communication overhead).
    ///
    /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
    /// graph calculation, and only we the allocator kept a reference to it, therefore it's free
    /// to be reused. However, in order for this to work, we need to guarantee the order of
    /// operation, so that this buffer is not being used by another kernel at the same time.
    /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
    ///
    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
    /// (strong_count = 1).
    pub(crate) buffers: AllocatedBuffers,
    /// Seed for random number generation.
    pub(crate) seed: Arc<Mutex<Buffer>>,
}

impl std::fmt::Debug for MetalDevice {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MetalDevice({:?})", self.id)
    }
}

impl std::ops::Deref for MetalDevice {
    type Target = metal::DeviceRef;

    fn deref(&self) -> &Self::Target {
        &self.device
    }
}

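A stand-alone illustration of the reuse rule described in the `buffers` doc comment above, using a plain `Vec<u8>` in place of an `Arc<metal::Buffer>` (the rule itself is just a strong-count check):

use std::sync::Arc;

let buf = Arc::new(vec![0u8; 1024]); // the allocator's reference
let user = buf.clone();              // a kernel still holds the buffer
assert_eq!(Arc::strong_count(&buf), 2); // in use: not reusable
drop(user);
assert_eq!(Arc::strong_count(&buf), 1); // only the allocator is left: reusable
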
impl MetalDevice {
    pub fn id(&self) -> DeviceId {
        self.id
    }

    pub fn metal_device(&self) -> &metal::Device {
        &self.device
    }

    pub fn command_queue(&self) -> &CommandQueue {
        &self.command_queue
    }

    pub fn command_buffer(&self) -> Result<CommandBuffer> {
        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
        let mut command_buffer = command_buffer_lock.to_owned();
        let mut index = self
            .command_buffer_index
            .try_write()
            .map_err(MetalError::from)?;
        if *index > self.compute_per_buffer {
            command_buffer.commit();
            command_buffer = self.command_queue.new_command_buffer().to_owned();
            *command_buffer_lock = command_buffer.clone();
            *index = 0;

            self.drop_unused_buffers()?;
        }
        *index += 1;
        Ok(command_buffer)
    }

    pub fn wait_until_completed(&self) -> Result<()> {
        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
        match command_buffer.status() {
            metal::MTLCommandBufferStatus::Committed
            | metal::MTLCommandBufferStatus::Scheduled
            | metal::MTLCommandBufferStatus::Completed => {
                panic!("Already committed");
            }
            _ => {}
        }
        command_buffer.commit();
        command_buffer.wait_until_completed();
        *command_buffer = self.command_queue.new_command_buffer().to_owned();

        Ok(())
    }

    pub fn kernels(&self) -> &Kernels {
        &self.kernels
    }

    pub fn device(&self) -> &metal::Device {
        &self.device
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer data cannot be read on the CPU directly.
    ///
    /// [`name`] is only used to keep track of the resource origin in case of bugs
    pub fn new_buffer(
        &self,
        element_count: usize,
        dtype: DType,
        name: &str,
    ) -> Result<Arc<Buffer>> {
        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
    }

    /// Creates a new buffer (not necessarily zeroed).
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    /// This means the buffer can be read on the CPU but will require manual
    /// synchronization when the CPU memory is modified
    /// Used as a bridge to gather data back from the GPU
    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
    }

    /// Creates a new buffer from data.
    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
    ///
    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
        let size = core::mem::size_of_val(data) as NSUInteger;
        let new_buffer = self.device.new_buffer_with_data(
            data.as_ptr() as *const c_void,
            size,
            MTLResourceOptions::StorageModeManaged,
        );
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        let subbuffers = buffers
            .entry((size, MTLResourceOptions::StorageModeManaged))
            .or_insert(vec![]);

        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());
        Ok(new_buffer)
    }

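Hypothetical usage of the managed-buffer constructor, assuming a `dev: &MetalDevice` in scope:

let data = [1.0f32, 2.0, 3.0, 4.0];
let buf = dev.new_buffer_with_data(&data)?; // managed storage, bytes copied in
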
    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
        let buffer = self.allocate_buffer(
            size_in_bytes as NSUInteger,
            MTLResourceOptions::StorageModePrivate,
            "allocate_zeros",
        )?;
        let command_buffer = self.command_buffer()?;
        command_buffer.set_label("zeros");
        let blit = command_buffer.new_blit_command_encoder();
        blit.fill_buffer(
            &buffer,
            metal::NSRange {
                location: 0,
                length: buffer.length(),
            },
            0,
        );
        blit.end_encoding();
        Ok(buffer)
    }

    fn find_available_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        buffers: &RwLockWriteGuard<BufferMap>,
    ) -> Option<Arc<Buffer>> {
        let mut best_buffer: Option<&Arc<Buffer>> = None;
        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
                for sub in subbuffers {
                    if Arc::strong_count(sub) == 1 {
                        best_buffer = Some(sub);
                        best_buffer_size = *buffer_size;
                    }
                }
            }
        }
        best_buffer.cloned()
    }

    fn drop_unused_buffers(&self) -> Result<()> {
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        for subbuffers in buffers.values_mut() {
            let newbuffers = subbuffers
                .iter()
                .filter(|s| Arc::strong_count(*s) > 1)
                .map(Arc::clone)
                .collect();
            *subbuffers = newbuffers;
        }
        Ok(())
    }

    /// The critical allocator algorithm
    fn allocate_buffer(
        &self,
        size: NSUInteger,
        option: MTLResourceOptions,
        _name: &str,
    ) -> Result<Arc<Buffer>> {
        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
            // Cloning also ensures we increment the strong count
            return Ok(b.clone());
        }

        let size = buf_size(size);
        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);

        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
        let new_buffer = Arc::new(new_buffer);
        subbuffers.push(new_buffer.clone());

        Ok(new_buffer)
    }

    /// Create a metal GPU capture trace on [`path`].
    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let capture = metal::CaptureManager::shared();
        let descriptor = metal::CaptureDescriptor::new();
        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
        descriptor.set_capture_device(self);
        descriptor.set_output_url(path);

        capture
            .start_capture(&descriptor)
            .map_err(MetalError::from)?;
        Ok(())
    }
}

fn buf_size(size: NSUInteger) -> NSUInteger {
    (size - 1).next_power_of_two() as NSUInteger
}
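The bucketing helper in isolation: a request is rounded to the power of two at or above `size - 1`, so most sizes round up and exact powers of two map to themselves.

assert_eq!((1u64 - 1).next_power_of_two(), 1);     // 1   -> 1
assert_eq!((100u64 - 1).next_power_of_two(), 128); // 100 -> 128
assert_eq!((128u64 - 1).next_power_of_two(), 128); // 128 -> 128
// Edge case worth noting: for size = 2^k + 1 the bucket is 2^k, one byte
// below the request (129 -> 128); whether that is intended is not obvious
// from this diff alone.
assert_eq!((129u64 - 1).next_power_of_two(), 128);
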
@@ -4,24 +4,13 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels::CallConvTranspose2dCfg;
 use candle_metal_kernels::Kernels;
-use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
+use metal::{Buffer, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
-use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
+use std::sync::{Arc, Mutex, RwLock, TryLockError};

-/// Unique identifier for cuda devices.
-#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
-pub struct DeviceId(usize);
-
-impl DeviceId {
-    fn new() -> Self {
-        // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
-        use std::sync::atomic;
-        static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
-        Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
-    }
-}
+mod device;
+pub use device::{DeviceId, MetalDevice};

 /// Simple way to catch lock error without
 /// depending on T
@@ -49,13 +38,6 @@ pub enum MetalError {
     Message(String),
     #[error(transparent)]
     KernelError(#[from] candle_metal_kernels::MetalKernelError),
-
-    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
-    MatMulNonContiguous {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
-        mnk: (usize, usize, usize),
-    },
     #[error("{0:?}")]
     LockError(LockError),
     #[error("{msg}, expected: {expected:?}, got: {got:?}")]
@@ -72,267 +54,6 @@ impl From<String> for MetalError {
     }
 }
-
-type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
-type AllocatedBuffers = Arc<RwLock<BufferMap>>;
-
-#[derive(Clone)]
-pub struct MetalDevice {
-    /// Unique identifier, the registryID is not sufficient as it identifies the GPU rather than
-    /// the device itself.
-    id: DeviceId,
-
-    /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
-    device: metal::Device,
-
-    /// Single command queue for the entire device.
-    command_queue: CommandQueue,
-    /// One command buffer at a time.
-    /// The scheduler works by allowing multiple
-    /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
-    /// on a single command buffer. Using a single command buffer would be fastest on the GPU but
-    /// prevents overlapping of CPU and GPU commands (because command buffer needs to be committed
-    /// to start to work).
-    /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
-    /// for their START time, but there's no guarantee that command buffer1 will finish before
-    /// command buffer2 starts (or there are metal bugs there)
-    command_buffer: Arc<RwLock<CommandBuffer>>,
-    /// Keeps track of the current amount of compute command encoders on the current
-    /// command buffer
-    /// Arc, RwLock because of the interior mutability.
-    command_buffer_index: Arc<RwLock<usize>>,
-    /// The maximum amount of [compute command encoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc) per [command buffer](https://developer.apple.com/documentation/metal/mtlcommandbuffer?language=objc)
-    compute_per_buffer: usize,
-    /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
-    /// Heavily used by [`candle_metal_kernels`]
-    kernels: Arc<Kernels>,
-    /// Simple allocator struct.
-    /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
-    /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
-    /// (could be linked to FFI communication overhead).
-    ///
-    /// Whenever a buffer has a strong_count==1, we can reuse it, it means it was dropped in the
-    /// graph calculation, and only we the allocator kept a reference to it, therefore it's free
-    /// to be reused. However, in order for this to work, we need to guarantee the order of
-    /// operation, so that this buffer is not being used by another kernel at the same time.
-    /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
-    ///
-    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
-    /// (strong_count = 1).
-    buffers: AllocatedBuffers,
-    /// Seed for random number generation.
-    seed: Arc<Mutex<Buffer>>,
-}
-
-impl std::fmt::Debug for MetalDevice {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "MetalDevice({:?})", self.id)
-    }
-}
-
-impl std::ops::Deref for MetalDevice {
-    type Target = metal::DeviceRef;
-
-    fn deref(&self) -> &Self::Target {
-        &self.device
-    }
-}
-
-impl MetalDevice {
-    pub fn id(&self) -> DeviceId {
-        self.id
-    }
-
-    pub fn metal_device(&self) -> &metal::Device {
-        &self.device
-    }
-
-    pub fn command_queue(&self) -> &CommandQueue {
-        &self.command_queue
-    }
-
-    pub fn command_buffer(&self) -> Result<CommandBuffer> {
-        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
-        let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self
-            .command_buffer_index
-            .try_write()
-            .map_err(MetalError::from)?;
-        if *index > self.compute_per_buffer {
-            command_buffer.commit();
-            command_buffer = self.command_queue.new_command_buffer().to_owned();
-            *command_buffer_lock = command_buffer.clone();
-            *index = 0;
-
-            self.drop_unused_buffers()?;
-        }
-        *index += 1;
-        Ok(command_buffer)
-    }
-
-    pub fn wait_until_completed(&self) -> Result<()> {
-        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
-        match command_buffer.status() {
-            metal::MTLCommandBufferStatus::Committed
-            | metal::MTLCommandBufferStatus::Scheduled
-            | metal::MTLCommandBufferStatus::Completed => {
-                panic!("Already committed");
-            }
-            _ => {}
-        }
-        command_buffer.commit();
-        command_buffer.wait_until_completed();
-        *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
-        Ok(())
-    }
-
-    pub fn kernels(&self) -> &Kernels {
-        &self.kernels
-    }
-
-    pub fn device(&self) -> &metal::Device {
-        &self.device
-    }
-
-    /// Creates a new buffer (not necessarily zeroed).
-    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
-    /// This means the buffer data cannot be read on the CPU directly.
-    ///
-    /// [`name`] is only used to keep track of the resource origin in case of bugs
-    pub fn new_buffer(
-        &self,
-        element_count: usize,
-        dtype: DType,
-        name: &str,
-    ) -> Result<Arc<Buffer>> {
-        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
-        self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
-    }
-
-    /// Creates a new buffer (not necessarily zeroed).
-    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
-    /// This means the buffer can be read on the CPU but will require manual
-    /// synchronization when the CPU memory is modified
-    /// Used as a bridge to gather data back from the GPU
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
-        self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
-    }
-
-    /// Creates a new buffer from data.
-    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
-    ///
-    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
-    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
-        let size = core::mem::size_of_val(data) as NSUInteger;
-        let new_buffer = self.device.new_buffer_with_data(
-            data.as_ptr() as *const c_void,
-            size,
-            MTLResourceOptions::StorageModeManaged,
-        );
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
-        let subbuffers = buffers
-            .entry((size, MTLResourceOptions::StorageModeManaged))
-            .or_insert(vec![]);
-
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-        Ok(new_buffer)
-    }
-
-    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
-        let buffer = self.allocate_buffer(
-            size_in_bytes as NSUInteger,
-            MTLResourceOptions::StorageModePrivate,
-            "allocate_zeros",
-        )?;
-        let command_buffer = self.command_buffer()?;
-        command_buffer.set_label("zeros");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.fill_buffer(
-            &buffer,
-            metal::NSRange {
-                location: 0,
-                length: buffer.length(),
-            },
-            0,
-        );
-        blit.end_encoding();
-        Ok(buffer)
-    }
-
-    fn find_available_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        buffers: &RwLockWriteGuard<BufferMap>,
-    ) -> Option<Arc<Buffer>> {
-        let mut best_buffer: Option<&Arc<Buffer>> = None;
-        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-                for sub in subbuffers {
-                    if Arc::strong_count(sub) == 1 {
-                        best_buffer = Some(sub);
-                        best_buffer_size = *buffer_size;
-                    }
-                }
-            }
-        }
-        best_buffer.cloned()
-    }
-
-    fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
-        Ok(())
-    }
-
-    /// The critical allocator algorithm
-    fn allocate_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        _name: &str,
-    ) -> Result<Arc<Buffer>> {
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
-        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
-            // Cloning also ensures we increment the strong count
-            return Ok(b.clone());
-        }
-
-        let size = buf_size(size);
-        let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
-
-        let new_buffer = self.device.new_buffer(size as NSUInteger, option);
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-
-        Ok(new_buffer)
-    }
-
-    /// Create a metal GPU capture trace on [`path`].
-    pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
-        let capture = metal::CaptureManager::shared();
-        let descriptor = metal::CaptureDescriptor::new();
-        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        descriptor.set_capture_device(self);
-        descriptor.set_output_url(path);
-
-        capture
-            .start_capture(&descriptor)
-            .map_err(MetalError::from)?;
-        Ok(())
-    }
-}
-
 #[derive(Debug, Clone)]
 pub struct MetalStorage {
     /// The actual buffer containing the data.
@@ -2055,10 +1776,6 @@ impl BackendDevice for MetalDevice {
     }
 }

-fn buf_size(size: NSUInteger) -> NSUInteger {
-    (size - 1).next_power_of_two() as NSUInteger
-}
-
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let ptr = buffer.contents() as *const T;
     assert!(!ptr.is_null());