git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : adapt AMX to tensor->grad removal (llama/0)
author: Georgi Gerganov <redacted>
Sat, 16 Nov 2024 19:38:01 +0000 (21:38 +0200)
committer: Georgi Gerganov <redacted>
Mon, 18 Nov 2024 08:56:51 +0000 (10:56 +0200)
ggml-ci

src/ggml-amx/ggml-amx.cpp

index 37da985399c584707da7f32ce65fcbc67ca9a1e0..8568e7965fd2ec65e34d55452f5656a2f865c7c9 100644 (file)
@@ -317,8 +317,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
             const enum ggml_type type = src0->type;
             const int64_t ne0 = op->ne[0];
 
-            bool is_training = src0->grad || src1->grad;
-
             // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
             // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
             bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
@@ -326,7 +324,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st
             bool can_use_amx =
                 is_contiguous_2d(src0) &&       // src0 must be contiguous
                 is_contiguous_2d(src1) &&       // src1 must be contiguous
-                !is_training &&                 // inference only
                 src1->type == GGML_TYPE_F32 &&  // src1 must be float32
                 has_amx_kernels &&              // with amx kernel impls
                 ne0 % (TILE_N * 2) == 0;        // out_features is 32x