mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Support for UG kernels. (#2579)
* Support for UG kernels. * Add a dedicated test.
This commit is contained in:
@ -70,6 +70,8 @@ tokenizers = { version = "0.19.1", default-features = false }
|
|||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
tracing-chrome = "0.7.1"
|
tracing-chrome = "0.7.1"
|
||||||
tracing-subscriber = "0.3.7"
|
tracing-subscriber = "0.3.7"
|
||||||
|
ug = "0.0.2"
|
||||||
|
ug-cuda = "0.0.2"
|
||||||
yoke = { version = "0.7.2", features = ["derive"] }
|
yoke = { version = "0.7.2", features = ["derive"] }
|
||||||
zip = { version = "1.1.1", default-features = false }
|
zip = { version = "1.1.1", default-features = false }
|
||||||
metal = { version = "0.27.0", features = ["mps"]}
|
metal = { version = "0.27.0", features = ["mps"]}
|
||||||
|
@ -28,6 +28,8 @@ rand_distr = { workspace = true }
|
|||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
safetensors = { workspace = true }
|
safetensors = { workspace = true }
|
||||||
thiserror = { workspace = true }
|
thiserror = { workspace = true }
|
||||||
|
ug = { workspace = true }
|
||||||
|
ug-cuda = { workspace = true, optional = true }
|
||||||
yoke = { workspace = true }
|
yoke = { workspace = true }
|
||||||
zip = { workspace = true }
|
zip = { workspace = true }
|
||||||
|
|
||||||
@ -39,7 +41,7 @@ criterion = { workspace = true }
|
|||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
cuda = ["cudarc", "dep:candle-kernels"]
|
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
accelerate = ["dep:libc", "dep:accelerate-src"]
|
accelerate = ["dep:libc", "dep:accelerate-src"]
|
||||||
|
@ -51,6 +51,27 @@ impl CudaDevice {
|
|||||||
self.device.clone()
|
self.device.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn compile(
|
||||||
|
&self,
|
||||||
|
func_name: &'static str,
|
||||||
|
kernel: ug::lang::ssa::Kernel,
|
||||||
|
) -> Result<CudaFunction> {
|
||||||
|
let mut buf = vec![];
|
||||||
|
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
|
||||||
|
let cuda_code = String::from_utf8(buf)?;
|
||||||
|
let opts = cudarc::nvrtc::CompileOptions {
|
||||||
|
use_fast_math: Some(true),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
|
||||||
|
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
|
||||||
|
let func = match self.device.get_func("ug", func_name) {
|
||||||
|
Some(func) => func,
|
||||||
|
None => crate::bail!("unknown function ug::{func_name}"),
|
||||||
|
};
|
||||||
|
Ok(func)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn id(&self) -> DeviceId {
|
pub fn id(&self) -> DeviceId {
|
||||||
self.id
|
self.id
|
||||||
}
|
}
|
||||||
|
@ -375,3 +375,70 @@ impl Tensor {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct UgIOp1 {
|
||||||
|
name: &'static str,
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
func: cudarc::driver::CudaFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UgIOp1 {
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn new(
|
||||||
|
name: &'static str,
|
||||||
|
kernel: ug::lang::ssa::Kernel,
|
||||||
|
device: &crate::Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
let device = device.as_cuda_device()?;
|
||||||
|
let func = device.compile(name, kernel)?;
|
||||||
|
Ok(Self { name, func })
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "cuda"))]
|
||||||
|
{
|
||||||
|
Ok(Self { name })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InplaceOp1 for UgIOp1 {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
|
||||||
|
crate::bail!("ug ops are only supported on cuda at the moment")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
|
||||||
|
crate::bail!("ug ops are only supported on cuda at the moment")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
|
||||||
|
use crate::cuda_backend::WrapErr;
|
||||||
|
use cudarc::driver::LaunchAsync;
|
||||||
|
|
||||||
|
let elem_count = layout.shape().elem_count();
|
||||||
|
// TODO: support more dtypes.
|
||||||
|
let sto = sto.as_cuda_slice::<f32>()?;
|
||||||
|
let sto = match layout.contiguous_offsets() {
|
||||||
|
None => crate::bail!("input has to be contiguous"),
|
||||||
|
Some((o1, o2)) => sto.slice(o1..o2),
|
||||||
|
};
|
||||||
|
let params = (&sto,);
|
||||||
|
let (g, b) = if elem_count % 32 == 0 {
|
||||||
|
(elem_count / 32, 32)
|
||||||
|
} else {
|
||||||
|
(elem_count, 1)
|
||||||
|
};
|
||||||
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
grid_dim: (g as u32, 1, 1),
|
||||||
|
block_dim: (b as u32, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
unsafe { self.func.clone().launch(cfg, params) }.w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -130,6 +130,14 @@ impl Device {
|
|||||||
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
|
||||||
|
match self {
|
||||||
|
Self::Cuda(d) => Ok(d),
|
||||||
|
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
|
||||||
|
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
|
pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
|
||||||
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
|
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
|
||||||
}
|
}
|
||||||
|
@ -165,6 +165,9 @@ pub enum Error {
|
|||||||
#[error("Metal error {0}")]
|
#[error("Metal error {0}")]
|
||||||
Metal(#[from] MetalError),
|
Metal(#[from] MetalError),
|
||||||
|
|
||||||
|
#[error(transparent)]
|
||||||
|
Ug(#[from] ug::Error),
|
||||||
|
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
TryFromIntError(#[from] core::num::TryFromIntError),
|
TryFromIntError(#[from] core::num::TryFromIntError),
|
||||||
|
|
||||||
@ -179,6 +182,10 @@ pub enum Error {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
ParseInt(#[from] std::num::ParseIntError),
|
ParseInt(#[from] std::num::ParseIntError),
|
||||||
|
|
||||||
|
/// Utf8 parse error.
|
||||||
|
#[error(transparent)]
|
||||||
|
FromUtf8(#[from] std::string::FromUtf8Error),
|
||||||
|
|
||||||
/// I/O error.
|
/// I/O error.
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
|
@ -77,7 +77,7 @@ mod variable;
|
|||||||
pub use cuda_backend::cudnn;
|
pub use cuda_backend::cudnn;
|
||||||
|
|
||||||
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
pub use cpu_backend::{CpuStorage, CpuStorageRef};
|
||||||
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
|
||||||
pub use device::{Device, DeviceLocation, NdArray};
|
pub use device::{Device, DeviceLocation, NdArray};
|
||||||
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
|
@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
#[allow(clippy::approx_constant)]
|
||||||
|
#[test]
|
||||||
|
fn ug_op() -> Result<()> {
|
||||||
|
let kernel = {
|
||||||
|
use ug::lang::op;
|
||||||
|
|
||||||
|
let layout = ug::Layout::from_shape(&[12]);
|
||||||
|
let ptr = op::Arg::ptr(ug::DType::F32);
|
||||||
|
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
|
||||||
|
let src = op::unary(op::UnaryOp::Exp, src)?;
|
||||||
|
let st = op::store(ptr.id(), layout, src)?;
|
||||||
|
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
|
||||||
|
let opts: ug::lower_op::Opts = Default::default();
|
||||||
|
kernel.lower(&opts.with_global(0, 12))?
|
||||||
|
};
|
||||||
|
let device = Device::new_cuda(0)?;
|
||||||
|
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
|
||||||
|
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
|
||||||
|
t.inplace_op1(&op)?;
|
||||||
|
assert_eq!(
|
||||||
|
to_vec1_round(&t, 4)?,
|
||||||
|
&[
|
||||||
|
1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
|
||||||
|
8103.0806, 22026.469, 59874.133
|
||||||
|
]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user