// find max logit and calculate mean
float max = cur_p->data[0].logit;
float logits_sum = 0;
+ size_t valid_count = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
- if (cur_p->data[i].logit > max) {
- max = cur_p->data[i].logit;
+ // Only count logits that are not -INFINITY
+ if (cur_p->data[i].logit != -INFINITY) {
+ if (cur_p->data[i].logit > max) {
+ max = cur_p->data[i].logit;
+ }
+ logits_sum += cur_p->data[i].logit;
+ valid_count++;
}
- logits_sum += cur_p->data[i].logit;
}
- float mean = logits_sum/cur_p->size;
+ float mean = valid_count > 0 ? logits_sum/valid_count : 0;
// calculate standard deviation
float acc = 0;
for (size_t i = 0; i < cur_p->size; ++i) {
- acc += pow(cur_p->data[i].logit - mean, 2);
+ // Skip -infinity in std calculation
+ if (cur_p->data[i].logit != -INFINITY) {
+ acc += pow(cur_p->data[i].logit - mean, 2);
+ }
}
- float std = sqrt(acc/cur_p->size);
+ float std = valid_count > 0 ? sqrt(acc/valid_count) : 0;
//apply mask
for (size_t i = 0; i < cur_p->size; ++i) {