From: Georgi Gerganov Date: Sun, 14 May 2023 11:45:13 +0000 (+0300) Subject: ggml : fix multi-threaded ggml_compute_forward_diag_mask_f32() X-Git-Tag: upstream/0.0.1642~1475 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0fb280133ac1dfc3ff7c7e6a5dce2a584af8f753;p=pkg%2Fggml%2Fsources%2Fggml ggml : fix multi-threaded ggml_compute_forward_diag_mask_f32() --- diff --git a/src/ggml.c b/src/ggml.c index 63da6799..58727bb3 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -10372,22 +10372,34 @@ static void ggml_compute_forward_diag_mask_f32( assert(src1->type == GGML_TYPE_I32); assert(ggml_nelements(src1) == 2); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + const int n_past = ((int32_t *) src1->data)[0]; + const bool inplace = (bool)((int32_t *) src1->data)[1]; + + if (params->type == GGML_TASK_INIT) { + // TODO: this hack is not good, need a better way to handle this + if (!inplace) { + // use the init task to copy src -> dst + struct ggml_compute_params params_cpy = *params; + + params_cpy.ith = 0; + params_cpy.nth = 1; + params_cpy.type = GGML_TASK_COMPUTE; + + ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst); + } + + return; + } + + if (params->type == GGML_TASK_FINALIZE) { return; } const int ith = params->ith; const int nth = params->nth; - const int n_past = ((int32_t *) src1->data)[0]; - const bool inplace = (bool)((int32_t *) src1->data)[1]; - assert(n_past >= 0); - if (!inplace) { - ggml_compute_forward_dup_same_cont(params, src0, dst); - } - // TODO: handle transposed/permuted matrices const int n = ggml_nrows(src0); @@ -10474,7 +10486,7 @@ static void ggml_compute_forward_soft_max_f32( for (int i1 = ir0; i1 < ir1; i1++) { float *sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); + float *dp = (float *)((char *) dst->data + i1*dst->nb[1]); #ifndef NDEBUG for (int i = 0; i < nc; ++i) {