]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X (llama/8800)
authorslaren <redacted>
Thu, 1 Aug 2024 13:26:22 +0000 (15:26 +0200)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 10:45:29 +0000 (13:45 +0300)
* cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X

* update asserts

* only use dmmv for supported types

* add test

src/ggml-cuda.cu
src/ggml-cuda/dmmv.cu
src/ggml-cuda/dmmv.cuh
tests/test-backend-ops.cpp

index c73ae40d49da61d76939ae70f1a8e230c962c996..b510777fb78f6fe7d582327fcbeeb047fe8cc04d 100644 (file)
@@ -1885,10 +1885,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+    bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
-        && src1->ne[1] == 1;
+        && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
     bool          use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
index d7a2a2513bd3eef8f870e0a0be469b375c6f51b5..96a5adef5b2b50f05e3ae6c67ab5fb166ce5fb0b 100644 (file)
@@ -500,7 +500,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
     const dim3 block_nums(block_num_y, 1, 1);
@@ -510,7 +510,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -519,7 +519,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -528,7 +528,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -537,7 +537,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -588,7 +588,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
+    return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
+        src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
+        src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
+        src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
+        src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
+        src0_type == GGML_TYPE_F16;
+}
index 4c5ebd475fdb59ec7f521857c43545fdc7d3b542..e727eb97f6aada6c0a74581421389bd56a8a8d75 100644 (file)
@@ -16,3 +16,5 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);
index 2fa59fd0aa2e80c9eea36888cf6d49040006fef4..5de70d5540ebe19b59ba4837a2c5b3f1c27020a8 100644 (file)
@@ -804,8 +804,7 @@ struct test_cpy : public test_case {
 
     test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            std::array<int64_t, 4> permute = {0, 0, 0, 0},
-            bool _dst_use_permute = false)
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
         : type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
           _src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
 
@@ -2269,6 +2268,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     for (ggml_type type_a : other_types) {
         for (ggml_type type_b : {GGML_TYPE_F32}) {
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), { 1,  1}, {1, 1}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
         }
     }