]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Vectorize load instructions in dmmv f16 CUDA kernel (llama/9816)
authoragray3 <redacted>
Mon, 14 Oct 2024 00:49:08 +0000 (01:49 +0100)
committerGeorgi Gerganov <redacted>
Wed, 16 Oct 2024 08:28:39 +0000 (11:28 +0300)
* Vectorize load instructions in dmmv f16 CUDA kernel

Replaces scalar with vector load instructions, which substantially
improves performance on NVIDIA HBM GPUs, e.g. gives a 1.27X overall
speedup for Meta-Llama-3-8B-Instruct-F16 BS1 inference evaluation on
H100 SXM 80GB HBM3. On GDDR GPUs, there is a slight (1.01X) speedup.

* addressed comment

* Update ggml/src/ggml-cuda/dmmv.cu

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/dmmv.cu

index 96a5adef5b2b50f05e3ae6c67ab5fb166ce5fb0b..00e21b5d77e3cf9989e99874d78c167a13cb4e1c 100644 (file)
@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
 static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
-
+    // load 2 halfs into register in a single instruction
+    const half2 x_reg = *((half2 *) &(x[ib + iqs]));
     // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
+    v.x = __low2float(x_reg);
+    v.y = __high2float(x_reg);
 }
 
 static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
@@ -476,13 +477,28 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
             // matrix multiplication
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 #ifdef GGML_CUDA_F16
-            tmp += __hmul2(v, {
-                y[iybs + iqs + j/qr + 0],
-                y[iybs + iqs + j/qr + y_offset]
-            });
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += __hmul2(v, y_reg);
+            }
+            else {
+                tmp += __hmul2(v, {
+                        y[iybs + iqs + j/qr + 0],
+                        y[iybs + iqs + j/qr + y_offset]
+                    });
+            }
 #else
-            tmp += v.x * y[iybs + iqs + j/qr + 0];
-            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += v.x * y_reg.x;
+                tmp += v.y * y_reg.y;
+            }
+            else {
+                tmp += v.x * y[iybs + iqs + j/qr + 0];
+                tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            }
 #endif // GGML_CUDA_F16
         }
     }