mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Factor some code out.
This commit is contained in:
@ -292,6 +292,35 @@ impl Map2 for MatMul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>(
|
||||||
|
s: &mut [T],
|
||||||
|
shape: &Shape,
|
||||||
|
dim: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
// [self] stores data in a contiguous way starting at offset 0.
|
||||||
|
let dims = shape.dims();
|
||||||
|
let elem_per_slice = dims[dim];
|
||||||
|
let prod_pre_dim = dims[..dim].iter().product();
|
||||||
|
let prod_post_dim = dims[dim + 1..].iter().product();
|
||||||
|
for pre_idx in 0..prod_pre_dim {
|
||||||
|
for post_idx in 0..prod_post_dim {
|
||||||
|
let mut sum = 0f64;
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
sum += s[idx].to_f64();
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
let sum = T::from_f64(sum);
|
||||||
|
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
||||||
|
for _ in 0..elem_per_slice {
|
||||||
|
s[idx] /= sum;
|
||||||
|
idx += prod_post_dim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
impl CpuStorage {
|
impl CpuStorage {
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
@ -437,85 +466,13 @@ impl CpuStorage {
|
|||||||
|
|
||||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||||
// [self] stores data in a contiguous way starting at offset 0.
|
// [self] stores data in a contiguous way starting at offset 0.
|
||||||
let dims = shape.dims();
|
|
||||||
let elem_per_slice = dims[dim];
|
|
||||||
let prod_pre_dim = dims[..dim].iter().product();
|
|
||||||
let prod_post_dim = dims[dim + 1..].iter().product();
|
|
||||||
match self {
|
match self {
|
||||||
Self::BF16(storage) => {
|
Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||||
for pre_idx in 0..prod_pre_dim {
|
Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||||
for post_idx in 0..prod_post_dim {
|
Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||||
let mut sum = 0f64;
|
Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
Self::U32(_) => Ok(()),
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
sum += storage[idx].to_f64();
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
let sum = bf16::from_f64(sum);
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
storage[idx] /= sum;
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Self::F16(storage) => {
|
|
||||||
for pre_idx in 0..prod_pre_dim {
|
|
||||||
for post_idx in 0..prod_post_dim {
|
|
||||||
let mut sum = 0f64;
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
sum += storage[idx].to_f64();
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
let sum = f16::from_f64(sum);
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
storage[idx] /= sum;
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Self::F32(storage) => {
|
|
||||||
for pre_idx in 0..prod_pre_dim {
|
|
||||||
for post_idx in 0..prod_post_dim {
|
|
||||||
let mut sum = 0f64;
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
sum += storage[idx] as f64;
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
let sum = sum as f32;
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
storage[idx] /= sum;
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Self::F64(storage) => {
|
|
||||||
for pre_idx in 0..prod_pre_dim {
|
|
||||||
for post_idx in 0..prod_post_dim {
|
|
||||||
let mut sum = 0f64;
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
sum += storage[idx];
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
|
|
||||||
for _ in 0..elem_per_slice {
|
|
||||||
storage[idx] /= sum;
|
|
||||||
idx += prod_post_dim
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Self::U32(_) => {}
|
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
|
@ -35,6 +35,7 @@ pub trait WithDType: Sized + Copy {
|
|||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn from_f64(v: f64) -> Self;
|
fn from_f64(v: f64) -> Self;
|
||||||
|
fn to_f64(self) -> f64;
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||||
|
|
||||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||||
@ -46,7 +47,7 @@ pub trait WithDType: Sized + Copy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! with_dtype {
|
macro_rules! with_dtype {
|
||||||
($ty:ty, $dtype:ident, $from_f64:expr) => {
|
($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
|
||||||
impl WithDType for $ty {
|
impl WithDType for $ty {
|
||||||
const DTYPE: DType = DType::$dtype;
|
const DTYPE: DType = DType::$dtype;
|
||||||
|
|
||||||
@ -54,6 +55,10 @@ macro_rules! with_dtype {
|
|||||||
$from_f64(v)
|
$from_f64(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_f64(self) -> f64 {
|
||||||
|
$to_f64(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||||
CpuStorage::$dtype(data)
|
CpuStorage::$dtype(data)
|
||||||
}
|
}
|
||||||
@ -82,8 +87,10 @@ macro_rules! with_dtype {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
with_dtype!(u32, U32, |v: f64| v as u32);
|
use half::{bf16, f16};
|
||||||
with_dtype!(half::f16, F16, half::f16::from_f64);
|
|
||||||
with_dtype!(half::bf16, BF16, half::bf16::from_f64);
|
with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
|
||||||
with_dtype!(f32, F32, |v: f64| v as f32);
|
with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
|
||||||
with_dtype!(f64, F64, |v: f64| v);
|
with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
|
||||||
|
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
||||||
|
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
||||||
|
Reference in New Issue
Block a user