mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
UG metal integration. (#2580)
This commit is contained in:
@ -72,6 +72,7 @@ tracing-chrome = "0.7.1"
|
|||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
ug = "0.0.2"
|
ug = "0.0.2"
|
||||||
ug-cuda = "0.0.2"
|
ug-cuda = "0.0.2"
|
||||||
|
ug-metal = "0.0.2"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "1.1.1", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
@ -30,6 +30,7 @@ safetensors = { workspace = true }
|
|||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
ug = { workspace = true }
|
ug = { workspace = true }
|
||||||
ug-cuda = { workspace = true, optional = true }
|
ug-cuda = { workspace = true, optional = true }
|
||||||
|
ug-metal = { workspace = true, optional = true }
|
||||||
yoke = { workspace = true }
|
yoke = { workspace = true }
|
||||||
zip = { workspace = true }
|
zip = { workspace = true }
|
||||||
|
|
||||||
@ -45,7 +46,7 @@ cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
|||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
metal = ["dep:metal", "dep:candle-metal-kernels"]
|
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "bench_main"
|
name = "bench_main"
|
||||||
|
@ -380,6 +380,8 @@ pub struct UgIOp1 {
|
|||||||
name: &'static str,
|
name: &'static str,
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
func: cudarc::driver::CudaFunction,
|
func: cudarc::driver::CudaFunction,
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
func: metal::ComputePipelineState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UgIOp1 {
|
impl UgIOp1 {
|
||||||
@ -395,7 +397,13 @@ impl UgIOp1 {
|
|||||||
let func = device.compile(name, kernel)?;
|
let func = device.compile(name, kernel)?;
|
||||||
Ok(Self { name, func })
|
Ok(Self { name, func })
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "cuda"))]
|
#[cfg(feature = "metal")]
|
||||||
|
{
|
||||||
|
let device = device.as_metal_device()?;
|
||||||
|
let func = device.compile(name, kernel)?;
|
||||||
|
Ok(Self { name, func })
|
||||||
|
}
|
||||||
|
#[cfg(not(any(feature = "cuda", feature = "metal")))]
|
||||||
{
|
{
|
||||||
Ok(Self { name })
|
Ok(Self { name })
|
||||||
}
|
}
|
||||||
@ -408,11 +416,43 @@ impl InplaceOp1 for UgIOp1 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
|
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
|
||||||
crate::bail!("ug ops are only supported on cuda at the moment")
|
crate::bail!("ug ops are only supported on metal/cuda at the moment")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
|
#[cfg(feature = "metal")]
|
||||||
crate::bail!("ug ops are only supported on cuda at the moment")
|
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
|
||||||
|
use crate::backend::BackendStorage;
|
||||||
|
use candle_metal_kernels::utils::EncoderProvider;
|
||||||
|
|
||||||
|
let elem_count = layout.shape().elem_count();
|
||||||
|
if sto.dtype() != crate::DType::F32 {
|
||||||
|
// TODO: support more dtypes.
|
||||||
|
crate::bail!("input is not a f32 tensor")
|
||||||
|
}
|
||||||
|
let device = sto.device();
|
||||||
|
println!("here");
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
let command_buffer = &command_buffer;
|
||||||
|
let encoder = command_buffer.encoder();
|
||||||
|
let encoder = encoder.as_ref();
|
||||||
|
encoder.set_compute_pipeline_state(&self.func);
|
||||||
|
let (g, b) = if elem_count % 32 == 0 {
|
||||||
|
(elem_count / 32, 32)
|
||||||
|
} else {
|
||||||
|
(elem_count, 1)
|
||||||
|
};
|
||||||
|
let grid_dims = metal::MTLSize {
|
||||||
|
width: g as u64,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
|
||||||
|
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));
|
||||||
|
|
||||||
|
encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_threads(grid_dims, group_dims);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
|
@ -138,6 +138,14 @@ impl Device {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
|
||||||
|
match self {
|
||||||
|
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
|
||||||
|
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
|
||||||
|
Self::Metal(d) => Ok(d),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
|
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
|
||||||
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
|
||||||
}
|
}
|
||||||
|
@ -144,6 +144,28 @@ impl MetalDevice {
|
|||||||
self.use_mlx_mm = use_mlx_mm
|
self.use_mlx_mm = use_mlx_mm
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn compile(
|
||||||
|
&self,
|
||||||
|
func_name: &'static str,
|
||||||
|
kernel: ug::lang::ssa::Kernel,
|
||||||
|
) -> Result<metal::ComputePipelineState> {
|
||||||
|
let mut buf = vec![];
|
||||||
|
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||||
|
let metal_code = String::from_utf8(buf)?;
|
||||||
|
let lib = self
|
||||||
|
.device
|
||||||
|
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
let func = lib
|
||||||
|
.get_function(func_name, None)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
let pl = self
|
||||||
|
.device
|
||||||
|
.new_compute_pipeline_state_with_function(&func)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
Ok(pl)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
pub fn id(&self) -> DeviceId {
|
||||||
self.id
|
self.id
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ fn inplace_op1() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cuda")]
|
#[cfg(any(feature = "cuda", feature = "metal"))]
|
||||||
#[allow(clippy::approx_constant)]
|
#[allow(clippy::approx_constant)]
|
||||||
#[test]
|
#[test]
|
||||||
fn ug_op() -> Result<()> {
|
fn ug_op() -> Result<()> {
|
||||||
@ -160,15 +160,21 @@ fn ug_op() -> Result<()> {
|
|||||||
let opts: ug::lower_op::Opts = Default::default();
|
let opts: ug::lower_op::Opts = Default::default();
|
||||||
kernel.lower(&opts.with_global(0, 12))?
|
kernel.lower(&opts.with_global(0, 12))?
|
||||||
};
|
};
|
||||||
let device = Device::new_cuda(0)?;
|
let device = if candle_core::utils::cuda_is_available() {
|
||||||
|
Device::new_cuda(0)?
|
||||||
|
} else if candle_core::utils::metal_is_available() {
|
||||||
|
Device::new_metal(0)?
|
||||||
|
} else {
|
||||||
|
candle_core::bail!("metal/cuda is mandatory for this test")
|
||||||
|
};
|
||||||
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
|
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
|
||||||
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
|
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
|
||||||
t.inplace_op1(&op)?;
|
t.inplace_op1(&op)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
to_vec1_round(&t, 4)?,
|
to_vec1_round(&t, 2)?,
|
||||||
&[
|
&[
|
||||||
1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
|
1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
|
||||||
8103.0806, 22026.469, 59874.133
|
59874.13
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -6,7 +6,7 @@ use std::collections::HashMap;
|
|||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
mod utils;
|
pub mod utils;
|
||||||
pub use utils::BufferOffset;
|
pub use utils::BufferOffset;
|
||||||
use utils::{get_block_dims, linear_split, EncoderProvider};
|
use utils::{get_block_dims, linear_split, EncoderProvider};
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
|
|||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
||||||
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||||
let mut pows0 = 0u64;
|
let mut pows0 = 0u64;
|
||||||
let mut pows1 = 0u64;
|
let mut pows1 = 0u64;
|
||||||
let mut pows2 = 0u64;
|
let mut pows2 = 0u64;
|
||||||
@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn set_param<P: EncoderParam>(
|
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
||||||
encoder: &ComputeCommandEncoderRef,
|
|
||||||
position: u64,
|
|
||||||
data: P,
|
|
||||||
) {
|
|
||||||
<P as EncoderParam>::set_param(encoder, position, data)
|
<P as EncoderParam>::set_param(encoder, position, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper functions to create the various objects on the compute command encoder
|
/// Helper functions to create the various objects on the compute command encoder
|
||||||
/// on a single line.
|
/// on a single line.
|
||||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||||
pub(crate) trait EncoderParam {
|
pub trait EncoderParam {
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||||
}
|
}
|
||||||
macro_rules! primitive {
|
macro_rules! primitive {
|
||||||
|
Reference in New Issue
Block a user