mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Fix the rebase.
This commit is contained in:
@ -681,7 +681,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
|
||||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||||
("usin", DType::F16) => contiguous::sin::HALF,
|
("usin", DType::F16) => contiguous::sin::HALF,
|
||||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||||
@ -698,7 +697,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uround", DType::F16) => contiguous::round::HALF,
|
("uround", DType::F16) => contiguous::round::HALF,
|
||||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
@ -730,7 +728,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uceil", DType::F32) => strided::ceil::FLOAT,
|
("uceil", DType::F32) => strided::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||||
("uround", DType::F32) => strided::round::FLOAT,
|
("uround", DType::F32) => strided::round::FLOAT,
|
||||||
("uabs", DType::F32) => strided::abs::FLOAT,
|
|
||||||
("ucos", DType::F16) => strided::cos::HALF,
|
("ucos", DType::F16) => strided::cos::HALF,
|
||||||
("usin", DType::F16) => strided::sin::HALF,
|
("usin", DType::F16) => strided::sin::HALF,
|
||||||
("usqr", DType::F16) => strided::sqr::HALF,
|
("usqr", DType::F16) => strided::sqr::HALF,
|
||||||
@ -745,7 +742,6 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uceil", DType::F16) => strided::ceil::HALF,
|
("uceil", DType::F16) => strided::ceil::HALF,
|
||||||
("ufloor", DType::F16) => strided::floor::HALF,
|
("ufloor", DType::F16) => strided::floor::HALF,
|
||||||
("uround", DType::F16) => strided::round::HALF,
|
("uround", DType::F16) => strided::round::HALF,
|
||||||
("uabs", DType::F16) => strided::abs::HALF,
|
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -370,7 +370,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
for (_, tensor) in model.tensor_infos.iter() {
|
for (_, tensor) in model.tensor_infos.iter() {
|
||||||
let elem_count = tensor.shape.elem_count();
|
let elem_count = tensor.shape.elem_count();
|
||||||
total_size_in_bytes +=
|
total_size_in_bytes +=
|
||||||
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
|
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||||
}
|
}
|
||||||
println!(
|
println!(
|
||||||
"loaded {:?} tensors ({}) in {:.2}s",
|
"loaded {:?} tensors ({}) in {:.2}s",
|
||||||
@ -387,7 +387,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
for (_, tensor) in model.tensors.iter() {
|
for (_, tensor) in model.tensors.iter() {
|
||||||
let elem_count = tensor.shape().elem_count();
|
let elem_count = tensor.shape().elem_count();
|
||||||
total_size_in_bytes +=
|
total_size_in_bytes +=
|
||||||
elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
|
elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
|
||||||
}
|
}
|
||||||
println!(
|
println!(
|
||||||
"loaded {:?} tensors ({}) in {:.2}s",
|
"loaded {:?} tensors ({}) in {:.2}s",
|
||||||
|
@ -179,7 +179,7 @@ macro_rules! ops{
|
|||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(
|
ops!(
|
||||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
|
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, round, erf, gelu_erf, tanh,
|
||||||
recip, abs
|
recip
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
|
@ -110,7 +110,6 @@ UNARY_OP(gelu_erf)
|
|||||||
UNARY_OP(erf)
|
UNARY_OP(erf)
|
||||||
UNARY_OP(tanh)
|
UNARY_OP(tanh)
|
||||||
UNARY_OP(recip)
|
UNARY_OP(recip)
|
||||||
UNARY_OP(abs)
|
|
||||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||||
@ -129,6 +128,7 @@ BFLOAT_UNARY_OP(neg)
|
|||||||
BFLOAT_UNARY_OP(exp)
|
BFLOAT_UNARY_OP(exp)
|
||||||
BFLOAT_UNARY_OP(log)
|
BFLOAT_UNARY_OP(log)
|
||||||
BFLOAT_UNARY_OP(gelu)
|
BFLOAT_UNARY_OP(gelu)
|
||||||
|
BFLOAT_UNARY_OP(abs)
|
||||||
BFLOAT_UNARY_OP(ceil)
|
BFLOAT_UNARY_OP(ceil)
|
||||||
BFLOAT_UNARY_OP(floor)
|
BFLOAT_UNARY_OP(floor)
|
||||||
BFLOAT_UNARY_OP(round)
|
BFLOAT_UNARY_OP(round)
|
||||||
@ -136,7 +136,6 @@ BFLOAT_UNARY_OP(gelu_erf)
|
|||||||
BFLOAT_UNARY_OP(erf)
|
BFLOAT_UNARY_OP(erf)
|
||||||
BFLOAT_UNARY_OP(tanh)
|
BFLOAT_UNARY_OP(tanh)
|
||||||
BFLOAT_UNARY_OP(recip)
|
BFLOAT_UNARY_OP(recip)
|
||||||
BFLOAT_UNARY_OP(abs)
|
|
||||||
|
|
||||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user