]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : fix dequantize kernel names (llama/4938)
authorGeorgi Gerganov <redacted>
Mon, 15 Jan 2024 11:27:00 +0000 (13:27 +0200)
committerGeorgi Gerganov <redacted>
Wed, 17 Jan 2024 19:21:09 +0000 (21:21 +0200)
ggml-cuda.cu

index a870718a745cbeee66b3df1006337f8969974ceb..c3e14bc96ec38fe89eef4571d8a4559cc231a523 100644 (file)
@@ -6309,14 +6309,14 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
 }
 
 template<typename dst_t>
-static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }
 
 template<typename dst_t>
-static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
     const int nb32 = k / 32;
     const int nb = (k + 255) / 256;
     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
@@ -6370,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
     int id;
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_q4_0_cuda;
+            return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
-            return dequantize_q4_1_cuda;
+            return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
@@ -6407,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
-            return dequantize_q4_0_cuda;
+            return dequantize_row_q4_0_cuda;
         case GGML_TYPE_Q4_1:
-            return dequantize_q4_1_cuda;
+            return dequantize_row_q4_1_cuda;
         case GGML_TYPE_Q5_0:
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1: