diff --git a/examples/llama/main.rs b/examples/llama/main.rs index 1db15816..e51beefd 100644 --- a/examples/llama/main.rs +++ b/examples/llama/main.rs @@ -220,10 +220,8 @@ impl Mlp { fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result { let shape = mask.shape(); - let _on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?; - // TODO: add an equivalent to where (or xla's select) so that we can use the following: - // let m = mask.where_cond(&on_true, on_false)?; - let m = on_false.clone(); + let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?; + let m = mask.where_cond(&on_true, on_false)?; Ok(m) } @@ -297,7 +295,7 @@ impl CausalSelfAttention { //let mask = Tensor::new(1u32, &device)? // .broadcast_as(&[t, t])? // .lower_triangle()? - let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?; + let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?; let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?; let att = att.softmax(att.rank() - 1)?; // Convert to contiguous as matmul doesn't support strided vs for now. 
diff --git a/src/backprop.rs b/src/backprop.rs
index 072a9005..ec6f4b59 100644
--- a/src/backprop.rs
+++ b/src/backprop.rs
@@ -24,6 +24,15 @@ impl Tensor {
                 nodes
             } else if let Some(op) = node.op() {
                 match op {
+                    Op::WhereCond(t1, t2, t3) => {
+                        let (tg, nodes) = walk(t1, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(t2, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(t3, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
                     Op::Add(lhs, rhs)
                     | Op::Mul(lhs, rhs)
                     | Op::Sub(lhs, rhs)
@@ -161,6 +170,9 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
                     }
+                    Op::WhereCond(_pred, _t, _f) => {
+                        return Err(Error::BackwardNotSupported { op: "where_cond" })
+                    }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
                     }
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index a2112a30..4571985a 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -13,6 +13,36 @@ pub enum CpuStorage {
     F64(Vec),
 }
 
+fn wcond<T: Copy>(
+    pred: &[u32],
+    shape: &Shape,
+    stride: &[usize],
+    t: &[T],
+    stride_t: &[usize],
+    f: &[T],
+    stride_f: &[usize],
+) -> Vec<T> {
+    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
+    {
+        let elem_count = shape.elem_count();
+        let pred = &pred[..elem_count];
+        let t = &t[..elem_count];
+        let f = &f[..elem_count];
+        pred.iter()
+            .zip(t.iter().zip(f.iter()))
+            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+            .collect::<Vec<_>>()
+    } else {
+        let dims = shape.dims();
+        let it_p = StridedIndex::new(dims, stride);
+        let it_t = StridedIndex::new(dims, stride_t);
+        let it_f = StridedIndex::new(dims, stride_f);
+        it_p.zip(it_t.zip(it_f))
+            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
+            .collect::<Vec<_>>()
+    }
+}
+
 fn unary_map U>(
     shape: &Shape,
     stride: &[usize],
@@ -402,6 +432,38 @@ impl CpuStorage {
         Ok(())
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
+    ) -> Result<Self> {
+        // TODO: Support types that could be casted to a boolean.
+        let pred = self.as_slice::<u32>()?;
+        match (t, f) {
+            (Self::F32(t), Self::F32(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F32(data))
+            }
+            (Self::F64(t), Self::F64(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F64(data))
+            }
+            (Self::U32(t), Self::U32(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::U32(data))
+            }
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: t.dtype(),
+                rhs: f.dtype(),
+                op: "where_cond",
+            }),
+        }
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         shape: &Shape,
@@ -410,25 +472,20 @@ impl CpuStorage {
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result {
-        match self {
-            CpuStorage::U32(ids) => match vs {
-                CpuStorage::F32(vs) => {
-                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
-                    Ok(CpuStorage::F32(storage))
-                }
-                CpuStorage::F64(vs) => {
-                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
-                    Ok(CpuStorage::F64(storage))
-                }
-                CpuStorage::U32(vs) => {
-                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
-                    Ok(CpuStorage::U32(storage))
-                }
-            },
-            ids => Err(Error::UnexpectedDType {
-                expected: DType::U32,
-                got: ids.dtype(),
-            }),
+        let ids = self.as_slice::<u32>()?;
+        match vs {
+            CpuStorage::F32(vs) => {
+                let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                Ok(CpuStorage::F32(storage))
+            }
+            CpuStorage::F64(vs) => {
+                let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                Ok(CpuStorage::F64(storage))
+            }
+            CpuStorage::U32(vs) => {
+                let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                Ok(CpuStorage::U32(storage))
+            }
         }
     }
 
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index 61125f93..f0ceea6a 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -411,6 
+411,18 @@ impl CudaStorage {
         }
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _t: &Self,
+        _stride_t: &[usize],
+        _f: &Self,
+        _stride_f: &[usize],
+    ) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement where_cond"))
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         _shape: &Shape,
diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs
index da7221e4..babc6e7d 100644
--- a/src/dummy_cuda_backend.rs
+++ b/src/dummy_cuda_backend.rs
@@ -90,6 +90,18 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        _: &Shape,
+        _: &[usize],
+        _: &Self,
+        _: &[usize],
+        _: &Self,
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         _: &Shape,
diff --git a/src/op.rs b/src/op.rs
index 1d93eee8..da096b5c 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -12,6 +12,7 @@ pub(crate) enum Op {
     BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
+    WhereCond(Tensor, Tensor, Tensor),
 
     Cat(Vec, usize),
 
diff --git a/src/storage.rs b/src/storage.rs
index 16f74995..e44a2db6 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -154,6 +154,35 @@ impl Storage {
         }
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
+    ) -> Result<Self> {
+        self.same_device(t, "where")?;
+        self.same_device(f, "where")?;
+        t.same_dtype(f, "where")?;
+        match (self, t, f) {
+            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
+                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
+                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                Ok(Self::Cuda(storage))
+            }
+            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "where",
+            }),
+        }
+    }
+
     pub(crate) fn 
embedding_impl(
         &self,
         shape: &Shape,
diff --git a/src/tensor.rs b/src/tensor.rs
index 32347328..2cf24a06 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -425,6 +425,29 @@ impl Tensor {
         Ok(from_storage(storage, c_shape, op, false))
     }
 
+    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
+        self.same_shape_binary_op(on_true, "where_cond")?;
+        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
+        let storage = self.storage.where_cond(
+            shape,
+            self.stride(),
+            &on_true.storage,
+            on_true.stride(),
+            &on_false.storage,
+            on_false.stride(),
+        )?;
+        let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
+            Some(Op::WhereCond(
+                self.clone(),
+                on_true.clone(),
+                on_false.clone(),
+            ))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
     pub fn embedding(ids: &Self, rhs: &Self) -> Result {
         if !rhs.is_contiguous() {
             return Err(Error::RequiresContiguous { op: "embedding" });