From 68f525f3219640750fcc4d3b84686bbfc0a0b8fa Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 21 Jun 2023 10:34:51 +0100 Subject: [PATCH] Move more bits to the backend part. --- src/cpu_backend.rs | 32 ++++++++++++++++++++++++++++++-- src/device.rs | 30 ++---------------------------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 03068866..01c17245 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -68,7 +68,7 @@ impl CpuStorage { // same if it helps. // https://github.com/ggerganov/llama.cpp/blob/aacdbd40562684665b6f7b8ba6695b7a2088bbb0/ggml.c#L7895 match (self, rhs) { - (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { + (Self::F32(lhs), Self::F32(rhs)) => { let lhs_index = StridedIndex::new(shape.dims(), lhs_stride); let rhs_index = StridedIndex::new(shape.dims(), rhs_stride); let data = lhs_index @@ -77,7 +77,7 @@ impl CpuStorage { .collect(); Ok(Self::F32(data)) } - (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { + (Self::F64(lhs), Self::F64(rhs)) => { let lhs_index = StridedIndex::new(shape.dims(), lhs_stride); let rhs_index = StridedIndex::new(shape.dims(), rhs_stride); let data = lhs_index @@ -96,4 +96,32 @@ impl CpuStorage { } } } + + pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { + let elem_count = shape.elem_count(); + match dtype { + DType::F32 => { + let data = vec![1f32; elem_count]; + Self::F32(data) + } + DType::F64 => { + let data = vec![1f64; elem_count]; + Self::F64(data) + } + } + } + + pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self { + let elem_count = shape.elem_count(); + match dtype { + DType::F32 => { + let data = vec![0f32; elem_count]; + Self::F32(data) + } + DType::F64 => { + let data = vec![0f64; elem_count]; + Self::F64(data) + } + } + } } diff --git a/src/device.rs b/src/device.rs index 3677cfff..af538c6c 100644 --- a/src/device.rs +++ b/src/device.rs @@ -56,20 +56,7 @@ impl NdArray for &[[S; N]; impl Device { pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Storage { match self { - Device::Cpu => { - let elem_count = shape.elem_count(); - let storage = match dtype { - DType::F32 => { - let data = vec![1f32; elem_count]; - CpuStorage::F32(data) - } - DType::F64 => { - let data = vec![1f64; elem_count]; - CpuStorage::F64(data) - } - }; - Storage::Cpu(storage) - } + Device::Cpu => Storage::Cpu(CpuStorage::ones_impl(shape, dtype)), Device::Cuda { gpu_id: _ } => { todo!() } @@ -78,20 +65,7 @@ impl Device { pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Storage { match self { - Device::Cpu => { - let elem_count = shape.elem_count(); - let storage = match dtype { - DType::F32 => { - let data = vec![0f32; elem_count]; - CpuStorage::F32(data) - } - DType::F64 => { - let data = vec![0f64; elem_count]; - CpuStorage::F64(data) - } - }; - Storage::Cpu(storage) - } + Device::Cpu => Storage::Cpu(CpuStorage::zeros_impl(shape, dtype)), Device::Cuda { gpu_id: _ } => { todo!() }