]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: MMQ support for iq4_nl, iq4_xs (llama/8278)
authorJohannes Gäßler <redacted>
Fri, 5 Jul 2024 07:06:31 +0000 (09:06 +0200)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/mmq.cu
src/ggml-cuda/mmq.cuh
src/ggml-cuda/template-instances/generate_cu_files.py
src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu [new file with mode: 0644]
src/ggml-cuda/vecdotq.cuh

index 650780fd2e462446e19490df2c1b17dbafe937c2..f24312dd0bc90284687eb3b6156d9a7f679a323e 100644 (file)
@@ -68,7 +68,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const int iqs4  = k_KQ %  QI4_0;
         const int shift = k_KQ & (QI8_1/2);
 
-        const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
         const int u = Q_q8[k_KQ_0/WARP_SIZE];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -108,7 +108,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const int iqs4  = k_KQ %  QI4_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
         const int u = Q_q8[k_KQ_0/WARP_SIZE];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
@@ -153,8 +153,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         const int iqs8  = k_KQ %  QI8_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
+        int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
         v |= (vh <<  4) & 0x00000010; // 0 ->  4
         v |= (vh << 11) & 0x00001000; // 1 -> 12
         v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -200,8 +200,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         const int iqs8  = k_KQ %  QI8_1;
         const int shift = k_KQ & (QI8_1/2);
 
-        int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
+        int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
+        const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
         v |= (vh <<  4) & 0x00000010; // 0 ->  4
         v |= (vh << 11) & 0x00001000; // 1 -> 12
         v |= (vh << 18) & 0x00100000; // 2 -> 20
@@ -249,7 +249,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         const int ib  = k_KQ / QI8_0;
         const int iqs = k_KQ % QI8_0;
 
-        const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
+        const int v = get_int_b2(K_q8_0[ib].qs, iqs);
 
         T Q_d;
         if (std::is_same<T, half>::value) {
@@ -408,7 +408,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
 
     const T   d   = x[ib].d;
     const int ql0 = x[ib].qs[iqs];
-    const int qh0 = get_int_from_uint8(x[ib].qh, 0);
+    const int qh0 = get_int_b2(x[ib].qh, 0);
     const int ql  = ((ql0 >> (4*shift)) & 0x0F);
     const int qh  = ((qh0 >> idq) << 4) & 0x10;
     const int q   = (ql | qh) - 16;
@@ -433,7 +433,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
 
     const half2 dm  = x[ib].dm;
     const int   ql0 = x[ib].qs[iqs];
-    const int   qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
+    const int   qh0 = get_int_b4(x[ib].qh, 0);
     const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
     const int   qh  = ((qh0 >> idq) << 4) & 0x10;
     const int   q   = (ql | qh);
index 0308beaccbaa3876fd508b194652ca53f15e6d1d..3af2eec2f4163277d5d1df90b75d1958c6a8686e 100644 (file)
@@ -59,6 +59,12 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_Q6_K:
             mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
             break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;
@@ -87,6 +93,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             mmq_supported = true;
             break;
         default:
index fca93618dc754b139a1b6a9e0db5a1e1964caacd..118e34d2847fc8a42237a9eaffd8f057ee202932 100644 (file)
@@ -92,15 +92,17 @@ static constexpr __device__ int get_mmq_y_device() {
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
     return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
-        type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
-        type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q5_0 :
-        type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q5_1 :
-        type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
-        type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
-        type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
-        type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
-        type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
-        type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_Q4_1    ? MMQ_DP4A_TXS_Q4_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_DP4A_TXS_Q5_1 :
+        type == GGML_TYPE_Q8_0    ? MMQ_DP4A_TXS_Q8_0 :
+        type == GGML_TYPE_Q2_K    ? MMQ_DP4A_TXS_Q2_K :
+        type == GGML_TYPE_Q3_K    ? MMQ_DP4A_TXS_Q3_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_DP4A_TXS_Q4_K :
+        type == GGML_TYPE_Q5_K    ? MMQ_DP4A_TXS_Q5_K :
+        type == GGML_TYPE_Q6_K    ? MMQ_DP4A_TXS_Q6_K :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_DP4A_TXS_Q5_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_DP4A_TXS_Q5_0 :
         tile_x_sizes{0, 0, 0};
 }
 
@@ -128,15 +130,17 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 :
-        type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 :
-        type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q5_0 :
-        type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q5_1 :
-        type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
-        type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
-        type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
-        type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K :
-        type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K :
-        type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_Q4_1    ? MMQ_MMA_TILE_X_K_Q4_1 :
+        type == GGML_TYPE_Q5_0    ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_Q5_1    ? MMQ_MMA_TILE_X_K_Q5_1 :
+        type == GGML_TYPE_Q8_0    ? MMQ_MMA_TILE_X_K_Q8_0 :
+        type == GGML_TYPE_Q2_K    ? MMQ_MMA_TILE_X_K_Q2_K :
+        type == GGML_TYPE_Q3_K    ? MMQ_MMA_TILE_X_K_Q3_K :
+        type == GGML_TYPE_Q4_K    ? MMQ_MMA_TILE_X_K_Q4_K :
+        type == GGML_TYPE_Q5_K    ? MMQ_MMA_TILE_X_K_Q5_K :
+        type == GGML_TYPE_Q6_K    ? MMQ_MMA_TILE_X_K_Q6_K :
+        type == GGML_TYPE_IQ4_XS  ? MMQ_MMA_TILE_X_K_Q5_0 :
+        type == GGML_TYPE_IQ4_NL  ? MMQ_MMA_TILE_X_K_Q5_0 :
         0;
 }
 
@@ -185,9 +189,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }
 
@@ -348,9 +352,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }
 
@@ -509,8 +513,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
 
-        const int ql = get_int_from_uint8(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
+        const int ql = get_int_b2(bxi->qs, kqsx);
+        const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
 
         int qs0 = (ql >>  0)   & 0x0F0F0F0F;
         qs0    |= (qh <<  4)   & 0x00000010;  // 0 ->  4
@@ -674,8 +678,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
 
-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
-        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
+        const int ql = get_int_b4(bxi->qs, kqsx);
+        const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
 
         int qs0 = (ql >>  0) & 0x0F0F0F0F;
         qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
@@ -839,9 +843,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b2(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }
 
@@ -984,7 +988,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
 
-        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
+        const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
 
 #pragma unroll
         for (int l = 0; l < QR2_K; ++l) {
@@ -1166,8 +1170,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
 
-        const int x_ql_0 = get_int_from_uint8(bxi->qs,    kqsx);
-        const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+        const int x_ql_0 = get_int_b2(bxi->qs,    kqsx);
+        const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
 
 #pragma unroll
         for (int l = 0; l < QR3_K; ++l) {
@@ -1225,11 +1229,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 
         const int ksc_low = ksc % (QI3_K/8);
         const int shift_low = 4 * (ksc / (QI3_K/8));
-        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+        const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
 
         const int ksc_high = QI3_K/8;
         const int shift_high = 2 * ksc;
-        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+        const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
@@ -1393,9 +1397,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #else
-        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        x_qs[i*(WARP_SIZE + 1)       + threadIdx.x] = get_int_b4(bxi->qs, kqsx);
 #endif // INT8_MMA_AVAILABLE
     }
 
@@ -1610,11 +1614,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
         const int ky = QR5_K*kqsx;
 
-        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql = get_int_b4(bxi->qs, kqsx);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4));
         const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
         const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
 
@@ -1832,11 +1836,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
         const int ky = QR6_K*kqsx;
 
-        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql = get_int_b2(bxi->ql, kqsx);
         const int ql0 = (ql >> 0) & 0x0F0F0F0F;
         const int ql1 = (ql >> 4) & 0x0F0F0F0F;
 
-        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
         const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
         const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
 
@@ -1883,9 +1887,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
 #ifdef INT8_MMA_AVAILABLE
-        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+        x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #else
-        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
+        x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #endif // INT8_MMA_AVAILABLE
     }
 }
@@ -2018,6 +2022,124 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #endif // INT8_MMA_AVAILABLE
 }
 
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_NL;
+    const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + kbxd] = __half2float(bxi->d);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = 0;           // threadIdx.x / QI4_XS
+    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q5_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+        const float d = __half2float(bxi->d);
+
+        const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+            | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q5_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_dp4a(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
@@ -2167,6 +2289,22 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
+    static constexpr int              vdr          = VDR_IQ4_NL_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_nl<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
+    static constexpr int              vdr          = VDR_IQ4_XS_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 static bool mmq_need_sum(const ggml_type type_x) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
@@ -2184,6 +2322,8 @@ static bool mmq_need_sum(const ggml_type type_x) {
         case GGML_TYPE_Q5_K:
             return true;
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
             return false;
         default:
             GGML_ASSERT(false);
@@ -2608,6 +2748,8 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
 
 // -------------------------------------------------------------------------------------------------------------------------
 
index ea58d09680231f3ce1482452492af9fdb5567c6d..ffeb3c27df2f43bc390385a41e90bf71cbca2c48 100755 (executable)
@@ -22,7 +22,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}
 
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
-    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
+    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
+    "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
 ]
 
 SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
diff --git a/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
new file mode 100644 (file)
index 0000000..eb02fab
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
diff --git a/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu b/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
new file mode 100644 (file)
index 0000000..1eb3b74
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
index 3f3a87c750b211e0f09bf28c741014ab3d8331a0..1d510484ac4742931fdb04daaabd2bfbea8a25c5 100644 (file)
@@ -1,36 +1,8 @@
 #include "common.cuh"
 #include <cstdint>
 
-static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] <<  0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
-
-    int x32 = 0;
-    x32 |= x16[0] <<  0;
-    x32 |= x16[1] << 16;
-
-    return x32;
-}
-
-static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
-static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
-    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
-}
-
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
-    const uint16_t * x16 = (const uint16_t *) x;
+    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
 
     int x32  = x16[2*i32 + 0] <<  0;
     x32     |= x16[2*i32 + 1] << 16;
@@ -768,6 +740,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }
 
 #define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ  2
 
 static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -802,6 +775,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 }
 
 #define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ  2
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -840,6 +814,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
 }
 
 #define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ  2
 
 static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -887,6 +862,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
 }
 
 #define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ  2
 
 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -921,6 +897,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 }
 
 #define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ  2
 
 // TODO: don't use lookup table for signs
 static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
@@ -962,6 +939,9 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
     return d * sumi;
 }
 
+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ  1
+
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
     const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
@@ -992,6 +972,9 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
     return d1q * (ds.x*sumi + ds.y*delta);
 }
 
+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ  1
+
 static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
@@ -1051,6 +1034,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
 }
 
 #define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ  4
 
 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
@@ -1074,6 +1058,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 }
 
 #define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ  4
 
 static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {