const int ith = params->ith;
const int nth = params->nth;
- if (ith >= HEADS) {
- return;
- }
-
- const int h_start = (HEADS * ith) / nth;
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
- (HEADS * (ith + 1)) / nth : HEADS;
+ const int h_start = (HEADS * (ith )) / nth;
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+ (HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
const int ith = params->ith;
const int nth = params->nth;
- if (ith >= HEADS) {
- return;
- }
-
- const int h_start = (HEADS * ith) / nth;
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
- (HEADS * (ith + 1)) / nth : HEADS;
+ const int h_start = (HEADS * (ith )) / nth;
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+ (HEADS * (ith + 1)) / nth : HEADS;
float * k = (float *) dst->src[0]->data;
float * v = (float *) dst->src[1]->data;
const int ith = params->ith;
const int nth = params->nth;
- if (ith >= HEADS) {
- return;
- }
-
- const int h_start = (HEADS * ith) / nth;
- const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
- (HEADS * (ith + 1)) / nth : HEADS;
+ const int h_start = (HEADS * (ith )) / nth;
+ const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+ (HEADS * (ith + 1)) / nth : HEADS;
float * r = (float *) dst->src[0]->data;
float * w = (float *) dst->src[1]->data;