]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : fix op mul check with command-r-plus (#10476)
authorDiego Devesa <redacted>
Sun, 24 Nov 2024 15:10:26 +0000 (16:10 +0100)
committerGitHub <redacted>
Sun, 24 Nov 2024 15:10:26 +0000 (16:10 +0100)
src/llama.cpp

index 001711037d5d19a12747d9a567fe713cadeab9dd..20df09b133bfb6739da68413ee4686ec058840f4 100644 (file)
@@ -7181,12 +7181,12 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
         case GGML_OP_ADD:
             {
-                ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512);
+                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
                 op_tensor = ggml_add(ctx, a, w);
             } break;
         case GGML_OP_MUL:
             {
-                ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512);
+                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
                 op_tensor = ggml_mul(ctx, a, w);
             } break;
         case GGML_OP_DIV: