]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : fix some mul mat cases + add tests for src1 F16 (#669)
authorbssrdf <redacted>
Fri, 29 Dec 2023 08:32:31 +0000 (03:32 -0500)
committerGitHub <redacted>
Fri, 29 Dec 2023 08:32:31 +0000 (10:32 +0200)
* fixed mul-mat error for old GPUs

* style fixes

* add mul mat src1 f16 test cases, fix more cases

ggml-ci

---------

Co-authored-by: bssrdf <redacted>
Co-authored-by: slaren <redacted>
src/ggml-backend.c
src/ggml-cuda.cu
src/ggml.c
tests/test-backend-ops.cpp

index 526ce732be5b5d8168dd9405a250fcb0d26c4816..2c3752067515fbcf632ec305203152854fe7503e 100644 (file)
@@ -614,10 +614,14 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
 }
 
 static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return true;
+    switch (op->op) {
+        case GGML_OP_MUL_MAT:
+            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+        default:
+            return true;
+    }
 
     GGML_UNUSED(backend);
-    GGML_UNUSED(op);
 }
 
 static struct ggml_backend_i cpu_backend_i = {
index abad9cc39e2cf051691ed51c0fd49a7434163aa5..9a9effcf58932b7d06f4dfdb1ea3bf0cdeefac80 100644 (file)
@@ -7485,6 +7485,8 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
     const int64_t ne00 = src0->ne[0];
     const int64_t row_diff = row_high - row_low;
 
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
     // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
 #ifdef GGML_CUDA_F16
     cuda_pool_alloc<half> src1_dfloat_a;
@@ -7577,6 +7579,7 @@ static void ggml_cuda_op_mul_mat_cublas(
     const int compute_capability = g_device_caps[id].cc;
 
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+        //printf("this branch\n");
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         cuda_pool_alloc<half> src0_as_f16;
         if (src0->type != GGML_TYPE_F16) {
@@ -7614,9 +7617,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    }
-    else {
+    } else {
         cuda_pool_alloc<float> src0_ddq_as_f32;
+        cuda_pool_alloc<float> src1_ddq_as_f32;
 
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -7624,7 +7627,15 @@ static void ggml_cuda_op_mul_mat_cublas(
             src0_ddq_as_f32.alloc(row_diff*ne00);
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
         }
+        if (src1->type != GGML_TYPE_F32) {
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+            GGML_ASSERT(to_fp32_cuda != nullptr);
+            src1_ddq_as_f32.alloc(src1_ncols*ne10);
+            to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+        }
+
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
         const float alpha = 1.0f;
         const float beta = 0.0f;
@@ -7633,9 +7644,9 @@ static void ggml_cuda_op_mul_mat_cublas(
         CUBLAS_CHECK(
             cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                     row_diff, src1_ncols, ne10,
-                    &alpha, src0_ddf_i, ne00,
-                            src1_ddf_i, ne10,
-                    &beta,  dst_dd_i,   ldc));
+                    &alpha, src0_ddf_i,  ne00,
+                            src1_ddf1_i, ne10,
+                    &beta,  dst_dd_i,    ldc));
     }
 
     (void) dst;
@@ -8035,6 +8046,7 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
@@ -8481,9 +8493,9 @@ static __global__ void k_compute_batched_ptrs(
     int64_t i03 = i13 / r3;
     int64_t i02 = i12 / r2;
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2   + i13*nbd3;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8492,28 +8504,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
+    const int64_t ne_dst = ggml_nelements(dst);
 
     ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -8522,7 +8516,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
     void * src0_ddq = src0_extra->data_device[g_main_device];
-    half * src0_as_f16 = (half *) src0_ddq;
+    half * src0_f16 = (half *) src0_ddq;
 
     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
     float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
@@ -8531,11 +8525,15 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
 
     // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    cuda_pool_alloc<half> src1_as_f16(ne1);
-    to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
+    cuda_pool_alloc<half> src1_f16_alloc;
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_cuda != nullptr);
+        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
+    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
 
     cuda_pool_alloc<half> dst_f16;
     char * dst_t;
@@ -8557,7 +8555,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const void * beta  = &beta_f16;
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne);
+        dst_t = (char *) dst_f16.alloc(ne_dst);
 
         nbd2 /= sizeof(float) / sizeof(half);
         nbd3 /= sizeof(float) / sizeof(half);
@@ -8604,9 +8602,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const char *) src0_as_f16,       CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                       (const char *) src1_as_f16.get(), CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                beta,  (      char *)       dst_t,       cu_data_type, ne01,                dst->nb[2]/sizeof(float), // strideC
+                alpha, (const char *) src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00,  // strideA
+                       (const char *) src1_f16, CUDA_R_16F,   nb11/nb10, nb12/nb10,  // strideB
+                beta,  (      char *)    dst_t, cu_data_type, ne01,       nb2/nb0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -8619,12 +8617,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16.get(), dst_t,
+                src0_f16, src1_f16, dst_t,
                 ptrs_src.get(), ptrs_dst.get(),
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                nb12, nb13,
+                src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
+                src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
                 nbd2, nbd3,
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
@@ -8632,8 +8631,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
+                alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
                 beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
                 ne23,
                 cu_compute_type,
@@ -8643,7 +8642,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
+        to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
 }
 
@@ -8682,13 +8681,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
+        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
index ed56e60a8893efa2c9e88b627f87728f65dac89c..a9e1ea9b40ec4eb7fd2d69fec1e4ef76255b85fa 100644 (file)
@@ -9687,7 +9687,7 @@ static void ggml_compute_forward_mul_mat(
             const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
             assert(params->wsize >= ne11*ne12*ne13*row_size);
-            assert(src1->type == GGML_TYPE_F32);
+            GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
             for (int64_t i13 = 0; i13 < ne13; ++i13) {
                 for (int64_t i12 = 0; i12 < ne12; ++i12) {
index f3df8a8c62a9a83961ed94234385d39de664b29a..b115299c0ce303740653183123742236ffd98997 100644 (file)
@@ -350,13 +350,18 @@ struct test_case {
         fflush(stdout);
 
         // check if backends support op
+        bool supported = true;
         for (ggml_backend_t backend : {backend1, backend2}) {
             if (!ggml_backend_supports_op(backend, out)) {
-                printf("not supported\n");
-                ggml_free(ctx);
-                return true;
+                printf("not supported [%s] ", ggml_backend_name(backend));
+                supported = false;
             }
         }
+        if (!supported) {
+            printf("\n");
+            ggml_free(ctx);
+            return true;
+        }
 
         // post-graph sentinel
         add_sentinel(ctx);
@@ -1505,8 +1510,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     for (ggml_type type_a : all_types) {
-        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
-            // FIXME: CPU crashes on f16xf16
+        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {2, 1}));