]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : fix const ptrs warning causing ROCm build issues (#3913)
authorGeorgi Gerganov <redacted>
Thu, 2 Nov 2023 18:32:11 +0000 (20:32 +0200)
committerGitHub <redacted>
Thu, 2 Nov 2023 18:32:11 +0000 (20:32 +0200)
ggml-cuda.cu

index 58b58f3315428618d7d32c7389735162adf73c8e..06c28f5651b72f9b0b0ff9f386220468e1830f07 100644 (file)
@@ -7248,7 +7248,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 
 __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
-        void ** ptrs,
+        const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
         int ne23,
         int nb02, int nb03,
@@ -7265,9 +7265,9 @@ __global__ void k_compute_batched_ptrs(
     int i03 = i13 / r3;
     int i02 = i12 / r2;
 
-    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
-    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7372,14 +7372,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     } else {
         // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;
-        // allocate device memory for pointers
-        size_t ptrs_s = 0;
-        void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream);
+
+        const void ** ptrs_src = nullptr;
+              void ** ptrs_dst = nullptr;
+
+        size_t ptrs_src_s = 0;
+        size_t ptrs_dst_s = 0;
+
+        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_as_f16, src1_as_f16, dst_f16,
-                ptrs_as,
+                ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
@@ -7390,15 +7396,18 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **       ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-        // free device memory for pointers
-        if (ptrs_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream);
+
+        if (ptrs_src_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+        }
+        if (ptrs_dst_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
         }
     }
 #endif