mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Simplify the pattern matching logic in the cuda backend.
@@ -487,6 +487,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
+        let start_gen = std::time::Instant::now();
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
         let input = Tensor::new(ctxt, &device)?;
         let logits = llama.forward(&input, &freqs_cis)?;
@@ -496,6 +497,7 @@ fn main() -> Result<()> {
         let next_token = distr.sample(&mut rng) as u32;
         tokens.push(next_token);
         new_tokens.push(next_token);
+        println!("> {:?}", start_gen.elapsed());
         println!(
             "{} token: {} '{}'",
             index + 1,
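
The two added lines above introduce a second start_gen binding inside the sampling loop: the outer binding still marks the start of generation, while the inner one shadows it so that the new println! reports per-token latency. A minimal standalone sketch of the same shadowing pattern (illustrative only, simplified from the example):

    use std::time::Instant;

    fn main() {
        let start_gen = Instant::now(); // overall generation timer
        for _index in 0..3 {
            let start_gen = Instant::now(); // shadows the outer timer for this iteration
            // ... sample one token here ...
            println!("> {:?}", start_gen.elapsed()); // per-token latency
        }
        println!("total: {:?}", start_gen.elapsed()); // outer binding is visible again here
    }
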
@@ -1,7 +1,7 @@
-use crate::{CpuStorage, DType, Layout, Shape};
+use crate::{CpuStorage, DType, Layout, Shape, WithDType};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
-use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaFunction, CudaSlice, DeviceRepr, DeviceSlice, LaunchAsync, LaunchConfig};
 use half::{bf16, f16};
 use std::sync::Arc;
 
@@ -243,6 +243,72 @@ enum CudaStorageSlice {
     F64(CudaSlice<f64>),
 }
 
+trait Map1 {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>>;
+
+    fn map(&self, s: &CudaStorageSlice, d: &CudaDevice, l: &Layout) -> Result<CudaStorageSlice> {
+        let out = match s {
+            CudaStorageSlice::U32(s) => CudaStorageSlice::U32(self.f(s, d, l)?),
+            CudaStorageSlice::BF16(s) => CudaStorageSlice::BF16(self.f(s, d, l)?),
+            CudaStorageSlice::F16(s) => CudaStorageSlice::F16(self.f(s, d, l)?),
+            CudaStorageSlice::F32(s) => CudaStorageSlice::F32(self.f(s, d, l)?),
+            CudaStorageSlice::F64(s) => CudaStorageSlice::F64(self.f(s, d, l)?),
+        };
+        Ok(out)
+    }
+}
+
+struct Clone;
+impl Map1 for Clone {
+    fn f<T: DeviceRepr>(
+        &self,
+        s: &CudaSlice<T>,
+        _: &CudaDevice,
+        _: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        Ok(s.try_clone()?)
+    }
+}
+
+struct Affine(f64, f64);
+
+impl Map1 for Affine {
+    fn f<T: DeviceRepr + WithDType>(
+        &self,
+        src: &CudaSlice<T>,
+        dev: &CudaDevice,
+        layout: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
+        let src = &src.slice(layout.start_offset()..);
+        let kernel_name = format!("affine_{}", T::DTYPE.as_str());
+        let func = dev.get_or_load_func(&kernel_name, kernels::AFFINE)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el) }?;
+        let params = (
+            el,
+            dims.len(),
+            &ds,
+            src,
+            &out,
+            T::from_f64(self.0),
+            T::from_f64(self.1),
+        );
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 fn slice_src_and_dst<'a, T>(
     src: &'a CudaSlice<T>,
     src_l: &Layout,
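
The new Map1 trait is what enables the simplification: map performs the per-dtype match over CudaStorageSlice exactly once and hands each variant to a generic f, so every unary op is written a single time for a generic element type T. As a hedged illustration of how a further op could plug into this pattern, the sketch below defines a hypothetical Scale op (not part of this commit) that reuses the Affine implementation shown above with an additive term of zero:

    // Hypothetical example, not part of this commit: a pure scaling op that
    // reuses the affine kernel path through the same Map1 machinery.
    struct Scale(f64);

    impl Map1 for Scale {
        fn f<T: DeviceRepr + WithDType>(
            &self,
            src: &CudaSlice<T>,
            dev: &CudaDevice,
            layout: &Layout,
        ) -> Result<CudaSlice<T>> {
            // Delegate to the Affine implementation above with add = 0.
            Affine(self.0, 0.0).f(src, dev, layout)
        }
    }

A caller would then dispatch over every supported dtype with a single line, e.g. let slice = Scale(2.0).map(&self.slice, &device, layout)?;, mirroring how try_clone and affine are rewritten in the hunks below.
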
@@ -332,14 +398,8 @@ fn gemm_config<T>(
 }
 
 impl CudaStorage {
-    pub fn try_clone(&self) -> Result<Self> {
-        let slice = match &self.slice {
-            CudaStorageSlice::U32(slice) => CudaStorageSlice::U32(slice.try_clone()?),
-            CudaStorageSlice::BF16(slice) => CudaStorageSlice::BF16(slice.try_clone()?),
-            CudaStorageSlice::F16(slice) => CudaStorageSlice::F16(slice.try_clone()?),
-            CudaStorageSlice::F32(slice) => CudaStorageSlice::F32(slice.try_clone()?),
-            CudaStorageSlice::F64(slice) => CudaStorageSlice::F64(slice.try_clone()?),
-        };
+    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
+        let slice = Clone.map(&self.slice, self.device(), layout)?;
         let device = self.device.clone();
         Ok(Self { slice, device })
     }
@@ -420,81 +480,8 @@ impl CudaStorage {
     }
 
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
-        let shape = layout.shape();
-        let dims = shape.dims();
-        let el_count = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el_count as u32);
-        let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
-        let slice = match &self.slice {
-            CudaStorageSlice::U32(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<u32>(el_count) }?;
-                let params = (el_count, dims.len(), &ds, arg, &out, mul as u32, add as u32);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::U32(out)
-            }
-            CudaStorageSlice::BF16(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<bf16>(el_count) }?;
-                let params = (
-                    el_count,
-                    dims.len(),
-                    &ds,
-                    arg,
-                    &out,
-                    bf16::from_f64(mul),
-                    bf16::from_f64(add),
-                );
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::BF16(out)
-            }
-            CudaStorageSlice::F16(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f16>(el_count) }?;
-                let params = (
-                    el_count,
-                    dims.len(),
-                    &ds,
-                    arg,
-                    &out,
-                    f16::from_f64(mul),
-                    f16::from_f64(add),
-                );
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F16(out)
-            }
-            CudaStorageSlice::F32(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f32>(el_count) }?;
-                let params = (el_count, dims.len(), &ds, arg, &out, mul as f32, add as f32);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F32(out)
-            }
-            CudaStorageSlice::F64(arg) => {
-                let arg = &arg.slice(layout.start_offset()..);
-                let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f64>(el_count) }?;
-                let params = (el_count, dims.len(), &ds, arg, &out, mul, add);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F64(out)
-            }
-        };
-        let device = dev.clone();
+        let device = self.device().clone();
+        let slice = Affine(mul, add).map(&self.slice, &device, layout)?;
         Ok(Self { slice, device })
     }
 
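
For reference, the affine kernel launched through Affine::f computes y[i] = mul * x[i] + add element-wise over the (possibly strided) input. A minimal CPU-side sketch of the same semantics for a contiguous f64 slice (illustrative only; note that the removed u32 path truncated mul and add with as u32):

    // CPU reference for the element-wise affine computation, ignoring
    // strides and start offsets for clarity.
    fn affine_reference(xs: &[f64], mul: f64, add: f64) -> Vec<f64> {
        xs.iter().map(|&x| x * mul + add).collect()
    }
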
@@ -44,7 +44,7 @@ impl CudaDevice {
 pub struct CudaStorage;
 
 impl CudaStorage {
-    pub fn try_clone(&self) -> Result<Self> {
+    pub fn try_clone(&self, _: &Layout) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
@@ -9,11 +9,11 @@ pub enum Storage {
 }
 
 impl Storage {
-    pub fn try_clone(&self) -> Result<Self> {
+    pub fn try_clone(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::Cpu(storage) => Ok(Self::Cpu(storage.clone())),
             Self::Cuda(storage) => {
-                let storage = storage.try_clone()?;
+                let storage = storage.try_clone(layout)?;
                 Ok(Self::Cuda(storage))
             }
         }
@@ -709,7 +709,7 @@ impl Tensor {
     pub fn copy(&self) -> Result<Tensor> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
-            storage: Arc::new(self.storage.try_clone()?),
+            storage: Arc::new(self.storage.try_clone(self.layout())?),
             layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
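
Taken together, the layout argument now flows through the whole clone path, ending in the single generic Clone implementation instead of a per-dtype match. A simplified sketch of the call chain after this commit (paraphrased from the hunks above, CUDA variant only):

    Tensor::copy()
      -> Storage::try_clone(self.layout())
           -> CudaStorage::try_clone(layout)
                -> Clone.map(&self.slice, self.device(), layout)
                     -> Clone::f::<T>(slice, _, _)   // one generic impl, instantiated per dtype by Map1::map
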