bool res = true;
res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
- res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
+ res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens);
return res;
}
bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
bool res = true;
- res &= pos->ne[0] == params.ubatch.n_tokens;
+ res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd;
return res;
}