mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Cleanup fixed a few ops removed debugging scaffolding.
This commit is contained in:

committed by
Nicolas Patry

parent
7cfffcac10
commit
2813fb5dbc
@ -105,8 +105,6 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
// TODO Is this necessary
|
||||
// self.buffer.synchronize();
|
||||
match self.dtype {
|
||||
DType::U8 => Ok(CpuStorage::U8(
|
||||
self.buffer.read_to_vec(self.buffer.length() as usize / 1),
|
||||
@ -140,6 +138,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
|
||||
assert!(layout.is_contiguous());
|
||||
assert!(layout.start_offset() == 0);
|
||||
assert_eq!(dtype, DType::F32);
|
||||
|
||||
let mut buffer = device.new_buffer(el, self.dtype);
|
||||
@ -173,10 +172,10 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
// debug!("TODO reduce_op {op:?} {sum_dims:?}");
|
||||
assert!(sum_dims.len() == 1);
|
||||
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
||||
assert!(layout.is_contiguous());
|
||||
assert!(layout.start_offset() == 0);
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
@ -269,13 +268,6 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// command_buffer.wait_until_scheduled();
|
||||
// debug!(
|
||||
// "cast {:?} - {:?} - {:?}",
|
||||
// dtype,
|
||||
// self.buffer.length(),
|
||||
// buffer.length()
|
||||
// );
|
||||
Ok(Self {
|
||||
buffer,
|
||||
device: device.clone(),
|
||||
@ -290,7 +282,7 @@ impl BackendStorage for MetalStorage {
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
@ -300,6 +292,7 @@ impl BackendStorage for MetalStorage {
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
@ -337,7 +330,9 @@ impl BackendStorage for MetalStorage {
|
||||
let el_count = shape.elem_count();
|
||||
let mut buffer = device.new_buffer(el_count, dtype);
|
||||
let command_buffer = device.command_queue.new_command_buffer();
|
||||
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
|
||||
if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
||||
{
|
||||
use candle_metal_kernels::binary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
@ -380,10 +375,10 @@ impl BackendStorage for MetalStorage {
|
||||
lhs_l.dims(),
|
||||
&self.buffer,
|
||||
&lhs_l.stride(),
|
||||
lhs_l.start_offset(),
|
||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&rhs.buffer,
|
||||
&rhs_l.stride(),
|
||||
rhs_l.start_offset(),
|
||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -420,11 +415,14 @@ impl BackendStorage for MetalStorage {
|
||||
"where_u8_f32",
|
||||
&dims,
|
||||
&self.buffer,
|
||||
(layout.stride(), layout.start_offset()),
|
||||
(
|
||||
layout.stride(),
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
&t.buffer,
|
||||
(&t_l.stride(), t_l.start_offset()),
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset()),
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -511,7 +509,9 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
assert!(src_l.is_contiguous());
|
||||
assert!(src_l.start_offset() == 0);
|
||||
assert!(ids_l.is_contiguous());
|
||||
assert!(ids_l.start_offset() == 0);
|
||||
let left_size: usize = src_l.dims()[..dim].iter().product();
|
||||
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
|
||||
let ids_el = ids_l.shape().elem_count();
|
||||
@ -681,6 +681,7 @@ impl BackendStorage for MetalStorage {
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let src_shape = src_l.shape();
|
||||
let el_count = src_shape.elem_count();
|
||||
// todo!("COPY STRIDED {src_shape:?} {el_count} {src_l:?} {dst_offset}");
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
@ -699,15 +700,13 @@ impl BackendStorage for MetalStorage {
|
||||
src_l.dims(),
|
||||
&self.buffer,
|
||||
&src_l.stride(),
|
||||
src_l.start_offset(),
|
||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut dst.buffer,
|
||||
dst_offset,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
// todo!("Output {:?}", dst.buffer.read_to_vec::<f32>(10));
|
||||
// }
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -732,24 +731,11 @@ impl BackendDevice for MetalDevice {
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
// let capture = metal::CaptureManager::shared();
|
||||
// let descriptor = metal::CaptureDescriptor::new();
|
||||
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||
// descriptor.set_capture_device(&device);
|
||||
// let mut dir = std::env::current_dir()?;
|
||||
// dir.push("out.gputrace");
|
||||
// descriptor.set_output_url(dir);
|
||||
|
||||
// capture
|
||||
// .start_capture(&descriptor)
|
||||
// .map_err(MetalError::from)?;
|
||||
let command_queue = device.new_command_queue();
|
||||
// let command_buffer = _command_queue.new_owned_command_buffer();
|
||||
let kernels = Arc::new(Kernels::new());
|
||||
Ok(Self {
|
||||
device,
|
||||
command_queue,
|
||||
// command_buffer,
|
||||
kernels,
|
||||
})
|
||||
}
|
||||
@ -819,9 +805,6 @@ impl BackendDevice for MetalDevice {
|
||||
option,
|
||||
),
|
||||
};
|
||||
// TODO is that necessary ?
|
||||
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
// debug!("Allocate 2 - buffer size {}", buffer.length());
|
||||
Ok(Self::Storage {
|
||||
buffer,
|
||||
device: self.clone(),
|
||||
|
@ -157,8 +157,6 @@ pub(crate) fn from_storage<S: Into<Shape>>(
|
||||
) -> Tensor {
|
||||
let dtype = storage.dtype();
|
||||
let device = storage.device();
|
||||
let shape = shape.into();
|
||||
// println!("{:?} {storage:?}", shape);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(RwLock::new(storage)),
|
||||
@ -168,11 +166,7 @@ pub(crate) fn from_storage<S: Into<Shape>>(
|
||||
dtype,
|
||||
device,
|
||||
};
|
||||
let result = Tensor(Arc::new(tensor_));
|
||||
// todo!(" from_storage");
|
||||
// let result = result.to_device(&Device::Cpu).unwrap();
|
||||
// todo!(" {result}");
|
||||
result
|
||||
Tensor(Arc::new(tensor_))
|
||||
}
|
||||
|
||||
impl Tensor {
|
||||
@ -1869,7 +1863,10 @@ impl Tensor {
|
||||
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
|
||||
(Storage::Metal(storage), Device::Cpu) => {
|
||||
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
|
||||
Storage::Cpu(storage.to_cpu_storage()?)
|
||||
}
|
||||
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
|
||||
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
|
||||
// are the same.
|
||||
|
@ -329,18 +329,14 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
println!("{tokens:?}");
|
||||
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..1 {
|
||||
for index in 0.. {
|
||||
if tokens.len() >= config.seq_len {
|
||||
break;
|
||||
}
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
// println!("Input {}", input);
|
||||
// println!("Input {}", input.to_device(&candle::Device::Cpu)?);
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||
|
@ -150,7 +150,7 @@ macro_rules! ops{
|
||||
}
|
||||
|
||||
pub mod unary {
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, copy);
|
||||
ops!(cos, sin, exp, sqr, sqrt, neg, copy, log);
|
||||
}
|
||||
pub mod binary {
|
||||
ops!(add, sub, mul, div);
|
||||
|
@ -63,6 +63,7 @@ UNARY_OP(sqr)
|
||||
UNARY_OP(sqrt)
|
||||
UNARY_OP(neg)
|
||||
UNARY_OP(exp)
|
||||
UNARY_OP(log)
|
||||
UNARY(id, float, copy_float, copy_float_strided)
|
||||
UNARY(id, half, copy_half, copy_half_strided)
|
||||
|
||||
@ -73,6 +74,7 @@ BFLOAT_UNARY_OP(sqr)
|
||||
BFLOAT_UNARY_OP(sqrt)
|
||||
BFLOAT_UNARY_OP(neg)
|
||||
BFLOAT_UNARY_OP(exp)
|
||||
BFLOAT_UNARY_OP(log)
|
||||
|
||||
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
|
||||
#endif
|
||||
|
@ -9,7 +9,6 @@ pub struct Embedding {
|
||||
|
||||
impl Embedding {
|
||||
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
||||
// todo!("Embedding {embeddings}");
|
||||
Self {
|
||||
embeddings,
|
||||
hidden_size,
|
||||
|
@ -156,7 +156,6 @@ impl CausalSelfAttention {
|
||||
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
|
||||
let x0 = x.narrow(D::Minus1, 0, 1)?;
|
||||
let x1 = x.narrow(D::Minus1, 1, 1)?;
|
||||
todo!("X {x1}");
|
||||
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
|
||||
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
|
||||
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
|
||||
@ -174,7 +173,6 @@ impl CausalSelfAttention {
|
||||
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
|
||||
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||
todo!("X {q}");
|
||||
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
if self.cache.use_kv_cache {
|
||||
@ -297,7 +295,6 @@ impl Block {
|
||||
let residual = x;
|
||||
let x = self.rms_1.forward(x)?;
|
||||
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||
todo!("---X {}", x);
|
||||
let residual = &x;
|
||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||
Ok(x)
|
||||
@ -330,7 +327,6 @@ impl Llama {
|
||||
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, _seq_len) = x.dims2()?;
|
||||
let mut x = self.wte.forward(x)?;
|
||||
//println!("Embeddings {}", self.wte.embeddings());
|
||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||
x = block.forward(&x, index_pos, block_idx)?;
|
||||
}
|
||||
|
Reference in New Issue
Block a user