mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Fix comments.
This commit is contained in:
@ -61,7 +61,7 @@ tracing-subscriber = "0.3.7"
|
|||||||
wav = "1.0.0"
|
wav = "1.0.0"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "0.6.6", default-features = false }
|
zip = { version = "0.6.6", default-features = false }
|
||||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||||
|
|
||||||
[profile.release-with-debug]
|
[profile.release-with-debug]
|
||||||
inherits = "release"
|
inherits = "release"
|
||||||
|
@ -13,7 +13,7 @@ readme = "README.md"
|
|||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
|
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
|
||||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
|
||||||
metal = { workspace = true, optional = true}
|
metal = { workspace = true, optional = true}
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
gemm = { workspace = true }
|
gemm = { workspace = true }
|
||||||
|
@ -54,10 +54,6 @@ impl std::ops::Deref for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MetalDevice {
|
impl MetalDevice {
|
||||||
// pub fn metal_device(&self) -> &metal::DeviceRef {
|
|
||||||
// self.device.as_ref()
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn id(&self) -> NSUInteger {
|
pub fn id(&self) -> NSUInteger {
|
||||||
self.registry_id()
|
self.registry_id()
|
||||||
}
|
}
|
||||||
@ -76,7 +72,6 @@ impl MetalDevice {
|
|||||||
|
|
||||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
||||||
// debug!("Allocate 1 - buffer size {size}");
|
|
||||||
self.device
|
self.device
|
||||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
}
|
}
|
||||||
@ -105,28 +100,22 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
let length = self.buffer.length() as usize;
|
||||||
|
let size = self.dtype.size_in_bytes();
|
||||||
|
if length % size != 0 {
|
||||||
|
crate::bail!(
|
||||||
|
"The Metal buffer length is not aligned with dtype {:?}",
|
||||||
|
self.dtype
|
||||||
|
);
|
||||||
|
}
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
DType::U8 => Ok(CpuStorage::U8(
|
DType::U8 => Ok(CpuStorage::U8(self.buffer.read_to_vec(length / size))),
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
DType::U32 => Ok(CpuStorage::U32(self.buffer.read_to_vec(length / size))),
|
||||||
)),
|
DType::I64 => Ok(CpuStorage::I64(self.buffer.read_to_vec(length / size))),
|
||||||
DType::U32 => Ok(CpuStorage::U32(
|
DType::F16 => Ok(CpuStorage::F16(self.buffer.read_to_vec(length / size))),
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
DType::BF16 => Ok(CpuStorage::BF16(self.buffer.read_to_vec(length / size))),
|
||||||
)),
|
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(length / size))),
|
||||||
DType::I64 => Ok(CpuStorage::I64(
|
DType::F64 => Ok(CpuStorage::F64(self.buffer.read_to_vec(length / size))),
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
|
||||||
)),
|
|
||||||
DType::F16 => Ok(CpuStorage::F16(
|
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
|
||||||
)),
|
|
||||||
DType::BF16 => Ok(CpuStorage::BF16(
|
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
|
||||||
)),
|
|
||||||
DType::F32 => Ok(CpuStorage::F32(
|
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
|
||||||
)),
|
|
||||||
DType::F64 => Ok(CpuStorage::F64(
|
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
|
||||||
)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,9 +126,9 @@ impl BackendStorage for MetalStorage {
|
|||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
|
|
||||||
assert!(layout.is_contiguous());
|
if layout.is_contiguous() || layout.start_offset() != 0|| dtype != DType::F32{
|
||||||
assert!(layout.start_offset() == 0);
|
crate::bail!("Not contiguous, non-f32 affine is not implemented yet.");
|
||||||
assert_eq!(dtype, DType::F32);
|
}
|
||||||
|
|
||||||
let mut buffer = device.new_buffer(el, self.dtype);
|
let mut buffer = device.new_buffer(el, self.dtype);
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
@ -153,7 +142,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
@ -164,18 +153,18 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("powf metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("elu metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||||
assert!(sum_dims.len() == 1);
|
|
||||||
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
if !(sum_dims.len() == 1 && sum_dims[0] == layout.shape().rank() - 1 && layout.is_contiguous() && layout.start_offset() == 0){
|
||||||
assert!(layout.is_contiguous());
|
crate::bail!("Non contiguous reduce op not supported yet");
|
||||||
assert!(layout.start_offset() == 0);
|
}
|
||||||
let device = self.device.clone();
|
let device = self.device.clone();
|
||||||
let src_stride = layout.stride();
|
let src_stride = layout.stride();
|
||||||
let src_dims = layout.shape().dims();
|
let src_dims = layout.shape().dims();
|
||||||
@ -204,7 +193,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
(ReduceOp::Max, DType::F32) => ("fast_max_float", true, false),
|
||||||
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_float", true, true),
|
||||||
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_float", true, true),
|
||||||
_ => todo!("Reduce op for non float"),
|
_ => crate::bail!("Reduce op for non float"),
|
||||||
};
|
};
|
||||||
if check_empty && layout.shape().elem_count() == 0 {
|
if check_empty && layout.shape().elem_count() == 0 {
|
||||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||||
@ -234,7 +223,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("cmp metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
@ -246,7 +235,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
if layout.is_contiguous() {
|
if layout.is_contiguous() {
|
||||||
let kernel_name = match (self.dtype, dtype) {
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||||
(left, right) => todo!("to dtype {left:?} - {right:?}"),
|
(left, right) => crate::bail!("to dtype {left:?} - {right:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_cast_contiguous(
|
candle_metal_kernels::call_cast_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -259,7 +248,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
todo!(
|
crate::bail!(
|
||||||
"TODO Implement the kernel calling cast {:?}-{:?}",
|
"TODO Implement the kernel calling cast {:?}-{:?}",
|
||||||
self.dtype,
|
self.dtype,
|
||||||
dtype
|
dtype
|
||||||
@ -293,7 +282,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_contiguous(
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -306,7 +295,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
crate::bail!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||||
}
|
}
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
@ -344,7 +333,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||||
("div", DType::F32) => contiguous::div::FLOAT,
|
("div", DType::F32) => contiguous::div::FLOAT,
|
||||||
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_contiguous(
|
candle_metal_kernels::call_binary_contiguous(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -365,7 +354,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("bsub", DType::F32) => strided::sub::FLOAT,
|
("bsub", DType::F32) => strided::sub::FLOAT,
|
||||||
("bmul", DType::F32) => strided::mul::FLOAT,
|
("bmul", DType::F32) => strided::mul::FLOAT,
|
||||||
("bdiv", DType::F32) => strided::div::FLOAT,
|
("bdiv", DType::F32) => strided::div::FLOAT,
|
||||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_binary_strided(
|
candle_metal_kernels::call_binary_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
@ -442,7 +431,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
_kernel_l: &Layout,
|
_kernel_l: &Layout,
|
||||||
_params: &ParamsConv1D,
|
_params: &ParamsConv1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("conv1d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv_transpose1d(
|
fn conv_transpose1d(
|
||||||
@ -452,7 +441,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
_kernel_l: &Layout,
|
_kernel_l: &Layout,
|
||||||
_params: &ParamsConvTranspose1D,
|
_params: &ParamsConvTranspose1D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("conv_transpose1d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv2d(
|
fn conv2d(
|
||||||
@ -462,7 +451,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
_kernel_l: &Layout,
|
_kernel_l: &Layout,
|
||||||
_params: &ParamsConv2D,
|
_params: &ParamsConv2D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("conv2d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv_transpose2d(
|
fn conv_transpose2d(
|
||||||
@ -472,27 +461,27 @@ impl BackendStorage for MetalStorage {
|
|||||||
_kernel_l: &Layout,
|
_kernel_l: &Layout,
|
||||||
_params: &ParamsConvTranspose2D,
|
_params: &ParamsConvTranspose2D,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("conv_tranpose2d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("avg_pool2d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("max_pool2d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("upsample_nearest1d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("upsample_nearest2d metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("gather metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_add(
|
||||||
@ -504,14 +493,13 @@ impl BackendStorage for MetalStorage {
|
|||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: usize,
|
_: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("scatter_add metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
assert!(src_l.is_contiguous());
|
if !(src_l.is_contiguous() && src_l.start_offset() == 0 && ids_l.is_contiguous() && ids_l.start_offset() == 0){
|
||||||
assert!(src_l.start_offset() == 0);
|
crate::bail!("Non contiguous index select not implemented");
|
||||||
assert!(ids_l.is_contiguous());
|
}
|
||||||
assert!(ids_l.start_offset() == 0);
|
|
||||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
let ids_el = ids_l.shape().elem_count();
|
let ids_el = ids_l.shape().elem_count();
|
||||||
@ -519,10 +507,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
let device = self.device();
|
let device = self.device();
|
||||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||||
let out = self.to_cpu_storage().unwrap();
|
let out = self.to_cpu_storage()?;
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
(DType::U32, DType::F32) => "is_u32_f32",
|
(DType::U32, DType::F32) => "is_u32_f32",
|
||||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
(left, right) => crate::bail!("index select metal {left:?} {right:?}"),
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_queue.new_command_buffer();
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
candle_metal_kernels::call_index_select(
|
candle_metal_kernels::call_index_select(
|
||||||
@ -556,7 +544,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: usize,
|
_: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
todo!()
|
crate::bail!("index_add metal")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn matmul(
|
fn matmul(
|
||||||
@ -666,11 +654,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
// let left = self.buffer.read_to_vec::<f32>(10);
|
|
||||||
// let right = rhs.buffer.read_to_vec::<f32>(10);
|
|
||||||
// let out = out_buffer.read_to_vec::<f32>(40);
|
|
||||||
// todo!("Out {left:?} {right:?} {out:?}");
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: out_buffer,
|
buffer: out_buffer,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
@ -681,7 +664,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
let src_shape = src_l.shape();
|
let src_shape = src_l.shape();
|
||||||
let el_count = src_shape.elem_count();
|
let el_count = src_shape.elem_count();
|
||||||
// todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
|
|
||||||
if el_count == 0 {
|
if el_count == 0 {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
@ -690,7 +672,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||||
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
dtype => crate::bail!("copy_strided not implemented for {dtype:?}"),
|
||||||
};
|
};
|
||||||
candle_metal_kernels::call_unary_strided(
|
candle_metal_kernels::call_unary_strided(
|
||||||
&self.device.device,
|
&self.device.device,
|
||||||
@ -741,7 +723,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||||
todo!("set_seed")
|
crate::bail!("set_seed")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn location(&self) -> crate::DeviceLocation {
|
fn location(&self) -> crate::DeviceLocation {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.3.0"
|
version = "0.3.1"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
@ -10,7 +10,7 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
|
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
#![allow(clippy::too_many_arguments)]
|
|
||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineDescriptor,
|
||||||
ComputePipelineState, Device, Function, Library, MTLSize,
|
ComputePipelineState, Device, Function, Library, MTLSize,
|
||||||
@ -156,14 +155,6 @@ pub mod binary {
|
|||||||
ops!(add, sub, mul, div);
|
ops!(add, sub, mul, div);
|
||||||
}
|
}
|
||||||
|
|
||||||
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
|
||||||
// let mut l = HashMap::new();
|
|
||||||
// l.insert("affine", AFFINE);
|
|
||||||
// l.insert("indexing", INDEXING);
|
|
||||||
// l.insert("unary", UNARY);
|
|
||||||
// l
|
|
||||||
// });
|
|
||||||
//
|
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum MetalKernelError {
|
pub enum MetalKernelError {
|
||||||
#[error("Could not lock kernel map: {0}")]
|
#[error("Could not lock kernel map: {0}")]
|
||||||
@ -197,21 +188,7 @@ impl Kernels {
|
|||||||
Self { libraries, funcs }
|
Self { libraries, funcs }
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
|
||||||
// let kernels = Self::new();
|
|
||||||
// kernels.load_libraries(device)?;
|
|
||||||
// Ok(kernels)
|
|
||||||
// }
|
|
||||||
|
|
||||||
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
|
||||||
// for name in LIBRARY_SOURCES.keys() {
|
|
||||||
// self.load_library(device, name)?;
|
|
||||||
// }
|
|
||||||
// Ok(())
|
|
||||||
// }
|
|
||||||
|
|
||||||
fn get_library_source(&self, source: Source) -> &'static str {
|
fn get_library_source(&self, source: Source) -> &'static str {
|
||||||
// LIBRARY_SOURCES.get(name).cloned()
|
|
||||||
match source {
|
match source {
|
||||||
Source::Affine => AFFINE,
|
Source::Affine => AFFINE,
|
||||||
Source::Unary => UNARY,
|
Source::Unary => UNARY,
|
||||||
@ -261,6 +238,7 @@ impl Kernels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_unary_contiguous(
|
pub fn call_unary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -270,8 +248,6 @@ pub fn call_unary_contiguous(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
|
||||||
// assert_eq!(input.length(), output.length());
|
|
||||||
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
@ -292,6 +268,8 @@ pub fn call_unary_contiguous(
|
|||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_unary_strided(
|
pub fn call_unary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -339,6 +317,7 @@ pub fn call_unary_strided(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_binary_contiguous(
|
pub fn call_binary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -349,8 +328,6 @@ pub fn call_binary_contiguous(
|
|||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
|
||||||
// assert_eq!(input.length(), output.length());
|
|
||||||
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
|
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
@ -373,6 +350,7 @@ pub fn call_binary_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_binary_strided(
|
pub fn call_binary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -425,6 +403,7 @@ pub fn call_binary_strided(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_cast_contiguous(
|
pub fn call_cast_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -434,8 +413,6 @@ pub fn call_cast_contiguous(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
|
||||||
// assert_eq!(input.length(), output.length());
|
|
||||||
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
@ -458,6 +435,7 @@ pub fn call_cast_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_reduce_contiguous(
|
pub fn call_reduce_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -508,6 +486,7 @@ pub fn call_reduce_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_last_softmax(
|
pub fn call_last_softmax(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -543,7 +522,6 @@ pub fn call_last_softmax(
|
|||||||
|
|
||||||
let width = std::cmp::min(
|
let width = std::cmp::min(
|
||||||
pipeline.max_total_threads_per_threadgroup(),
|
pipeline.max_total_threads_per_threadgroup(),
|
||||||
// (elements_to_sum as u64 + 2 - 1) / 2,
|
|
||||||
elements_to_sum as u64,
|
elements_to_sum as u64,
|
||||||
)
|
)
|
||||||
.next_power_of_two();
|
.next_power_of_two();
|
||||||
@ -559,6 +537,7 @@ pub fn call_last_softmax(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_affine(
|
pub fn call_affine(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -590,6 +569,7 @@ pub fn call_affine(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_where_cond_strided(
|
pub fn call_where_cond_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -643,6 +623,7 @@ pub fn call_where_cond_strided(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_index_select(
|
pub fn call_index_select(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -813,7 +794,6 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn cos_f32_strided() {
|
fn cos_f32_strided() {
|
||||||
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
|
||||||
// Shape = [6], strides = [1];
|
|
||||||
let shape = vec![6];
|
let shape = vec![6];
|
||||||
let strides = vec![1];
|
let strides = vec![1];
|
||||||
let offset = 0;
|
let offset = 0;
|
||||||
|
Reference in New Issue
Block a user