diff --git a/candle-examples/examples/replit-code/README.md b/candle-examples/examples/replit-code/README.md new file mode 100644 index 00000000..84ed4c1c --- /dev/null +++ b/candle-examples/examples/replit-code/README.md @@ -0,0 +1,45 @@ +# candle-replit-code: code completion specialized model. + +[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a +language model specialized for code completion. This model uses 3.3B parameters +in `bfloat16` (so the GPU version will only work on recent nvidia cards). + +## Running an example + +```bash +cargo run --example replit-code --release -- --prompt 'def fibonacci(n): ' +``` +This produces the following output, which actually doesn't generate the Fibonacci +series properly. + +``` +def fibonacci(n): # write Fibonacci series up to n + """Print a Fibonacci series up to n.""" + + assert type(n) == int, "n must be an integer" + + if (type(fib_list)==None or len==0 ): + fib_list = [1] + + for i in range((len-2)): # start at 2nd element of list and go until end. + n += 1 + + print("Fibonacci number",n,"is:",i) + +def main(): + """Call the functions.""" + + userInput=input('Enter a positive integer: ') + + fibonacci(userInput) + + + + + + + +if __name__ == '__main__': # only run if this file is called directly. + print("This program prints out Fibonacci numbers.") + main() +``` diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs index 97429b7b..87b7d216 100644 --- a/candle-examples/examples/replit-code/main.rs +++ b/candle-examples/examples/replit-code/main.rs @@ -139,7 +139,7 @@ struct Args { seed: u64, /// The length of the sample to generate (in tokens). 
- #[arg(long, short = 'n', default_value_t = 100)] + #[arg(long, short = 'n', default_value_t = 1000)] sample_len: usize, #[arg(long)] diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs index f382a4bb..c1efe16f 100644 --- a/candle-transformers/src/models/mpt.rs +++ b/candle-transformers/src/models/mpt.rs @@ -103,23 +103,25 @@ impl GroupedQueryAttention { (k, v) } }; - let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?; - let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?; + self.kv_cache = Some((key.clone(), value.clone())); + let query = query.contiguous()?; + let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?; + let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?; let attn_weights = (query.matmul(&key)? * self.softmax_scale)?; let attn_bias = { let s_q = query.dim(D::Minus2)?; let s_k = key.dim(D::Minus1)?; let (_, _, a_q, a_k) = self.attn_bias.dims4()?; - self.attn_bias - .narrow(2, a_q - s_q, s_q)? - .narrow(3, a_k - s_k, s_k)? + let start_q = a_q.saturating_sub(s_q); + let start_k = a_k.saturating_sub(s_k); + self.attn_bias.i((.., .., start_q.., start_k..))? }; - let attn_weights = (attn_weights + attn_bias)?; + let attn_weights = attn_weights.broadcast_add(&attn_bias)?; let attn_weights = match mask { None => attn_weights, Some(mask) => masked_fill( &attn_weights, - &mask.broadcast_left(b_size * self.n_heads)?, + &mask.broadcast_as(attn_weights.shape())?, f32::NEG_INFINITY, )?, }; @@ -128,7 +130,8 @@ impl GroupedQueryAttention { .matmul(&value)? .transpose(1, 2)? 
.flatten_from(D::Minus2)?; - attn_output.apply(&self.out_proj) + let out = attn_output.apply(&self.out_proj)?; + Ok(out) } } @@ -199,7 +202,7 @@ impl MPTBlock { let xs = self.attn.forward(&xs, mask)?; let xs = (xs + residual)?; let residual = &xs; - let xs = xs.apply(&self.norm2)?.apply(&self.ffn); + let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?; xs + residual } } @@ -275,12 +278,15 @@ impl Model { Some(get_mask(seq_len, xs.device())?) }; for block in self.blocks.iter_mut() { - xs = block.forward(&xs, mask.as_ref())? + xs = block.forward(&xs, mask.as_ref())?; } - xs.narrow(1, seq_len - 1, 1)? + let xs = xs.apply(&self.norm_f)?; + let logits = xs + .narrow(1, seq_len - 1, 1)? .squeeze(1)? .matmul(&self.wte.embeddings().t()?)? - .squeeze(1) + .squeeze(1)?; + Ok(logits) } }