]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda : compute ptrs for cublasGemmBatchedEx in a kernel (#3891)
authorslaren <redacted>
Wed, 1 Nov 2023 22:10:09 +0000 (23:10 +0100)
committerGitHub <redacted>
Wed, 1 Nov 2023 22:10:09 +0000 (23:10 +0100)
* ggml-cuda : compute ptrs for cublasGemmBatchedEx in a kernel

* fix warnings

ggml-cuda.cu

index 12ee10e3d9bdcdbc59ec7c14cfa842da89ee2ccd..61cd1747cac4fc1f917e644031bb01bc80e66f0e 100644 (file)
@@ -6696,8 +6696,10 @@ inline void ggml_cuda_op_clamp(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const float min = ((float *) dst->op_params)[0];
-    const float max = ((float *) dst->op_params)[1];
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 
     clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
     CUDA_CHECK(cudaGetLastError());
@@ -7221,6 +7223,30 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
+__global__ void k_compute_batched_ptrs(
+        const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
+        void ** ptrs,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3) {
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -7322,49 +7348,35 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         const int ne23 = ne12*ne13;
 
-        // TODO: avoid this alloc
-        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
-                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
-            }
-        }
-
-        // allocate device memory for pointers
         void ** ptrs_as = nullptr;
-        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
-
-        // TODO: this does not work for some reason -- not sure why?
-        //size_t ptrs_s = 0;
-        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
-
-        // copy pointers to device
-        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
-
-        free(ptrs);
+        size_t ptrs_s = 0;
+        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_as_f16, src1_as_f16, dst_f16,
+                ptrs_as,
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                nb12, nb13,
+                dst->nb[2], dst->nb[3],
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **       ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        // free device memory for pointers
-        CUDA_CHECK(cudaFree(ptrs_as));
-        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        ggml_cuda_pool_free(ptrs_as, ptrs_s);
     }
 #endif