From: Georgi Gerganov Date: Mon, 10 Apr 2023 19:39:24 +0000 (+0300) Subject: whisper : sync with whisper.cpp X-Git-Tag: upstream/0.0.1642~1558 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=40be18a978bec9f0796e74370781f4395541d67f;p=pkg%2Fggml%2Fsources%2Fggml whisper : sync with whisper.cpp --- diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index 71f7a96c..24b9e5d8 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -1511,8 +1511,7 @@ static bool whisper_encode_internal( Vcur, n_state/n_head, n_head, n_ctx), 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head) - ); + ggml_new_tensor_3d(ctx0, wctx.wtype, n_ctx, n_state/n_head, n_head)); struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); #else @@ -1736,10 +1735,12 @@ static bool whisper_encode_internal( wstate.use_buf(ctx0, -1); - //struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor* k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); - struct ggml_tensor* v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx)); + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + struct ggml_tensor * k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1874,20 +1875,24 @@ static bool whisper_decode_internal( Kcur = ggml_scale(ctx0, Kcur, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); - - Vcur = ggml_add(ctx0, - ggml_repeat(ctx0, - layer.attn_v_b, - Vcur), - Vcur); - // store key and value to memory { + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + ggml_repeat(ctx0, + layer.attn_v_b, + Vcur), + Vcur); + + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, N)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_state, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + n_past*ggml_element_size(kv_self.v)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); @@ -1926,16 +1931,14 @@ static bool whisper_decode_internal( struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - struct ggml_tensor * V_trans = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), - n_state/n_head, n_head, n_past + N), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head)); + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_past + N, n_state/n_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_state/n_head, + il*n_ctx*ggml_element_size(kv_self.v)*n_state); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); @@ -1998,15 +2001,22 @@ static bool whisper_decode_internal( ggml_view_1d(ctx0, wstate.kv_cross.k, M*n_state, il*M*ggml_element_size(wstate.kv_cross.k)*n_state), n_state/n_head, n_head, M); - struct ggml_tensor * Vcross = - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), - n_state/n_head, n_head, M); + //struct ggml_tensor * Vcross = + // ggml_reshape_3d(ctx0, + // ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state), + // n_state/n_head, n_head, M); - struct ggml_tensor * V_trans = - ggml_cpy(ctx0, - ggml_permute(ctx0, Vcross, 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); + //struct ggml_tensor * V_trans = + // ggml_cpy(ctx0, + // ggml_permute(ctx0, Vcross, 1, 2, 0, 3), + // ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head)); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, wstate.kv_cross.v, + M, n_state/n_head, n_head, + M*ggml_element_size(wstate.kv_cross.v), + M*ggml_element_size(wstate.kv_cross.v)*n_state/n_head, + il*M*ggml_element_size(wstate.kv_cross.v)*n_state); // ------ @@ -2033,7 +2043,7 @@ static bool whisper_decode_internal( struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);