]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: initial support for IQ4_XS quantization (llama/11501)
authorRémy O <redacted>
Thu, 6 Feb 2025 06:09:59 +0000 (07:09 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
13 files changed:
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/types.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 2e1bcf691b3b0ce259345e95aa6b956f8664371c..1c99ebe2e2cabbecdfe9fef1d0b28f4aa0da2bb0 100644 (file)
@@ -1622,6 +1622,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
         //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
         //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
+        //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
         CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
 #undef CREATE_FA
 
@@ -1655,6 +1656,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f16,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f16,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f16,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f16,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
@@ -1673,6 +1675,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f16,   , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f16,   , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f16,  , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f16,  , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
@@ -1726,6 +1729,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         } else {
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -1744,6 +1748,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         }
 
@@ -1770,6 +1775,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+                CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             } else {
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -1788,6 +1794,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+                CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
                 CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             }
         }
@@ -1837,6 +1844,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc,   matmul_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc,   matmul_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc,  matmul_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc,  matmul_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
@@ -1861,6 +1869,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc,   matmul_id_iq2_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
 #undef CREATE_MM2
@@ -1902,6 +1911,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
         // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
@@ -1926,6 +1936,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
 #undef CREATE_MM
@@ -1962,6 +1973,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i],   "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1),   mul_mat_vec_iq2_s_f32_f32_len,   mul_mat_vec_iq2_s_f32_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1),   mul_mat_vec_iq3_s_f32_f32_len,   mul_mat_vec_iq3_s_f32_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_xs_f32_f32_len,  mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_nl_f32_f32_len,  mul_mat_vec_iq4_nl_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -1981,6 +1993,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i],   "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1),   mul_mat_vec_iq2_s_f16_f32_len,   mul_mat_vec_iq2_s_f16_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1),   mul_mat_vec_iq3_s_f16_f32_len,   mul_mat_vec_iq3_s_f16_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_xs_f16_f32_len,  mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_nl_f16_f32_len,  mul_mat_vec_iq4_nl_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
     }
 
@@ -2001,6 +2014,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S],   "mul_mat_vec_id_iq2_s_f32",   mul_mat_vec_id_iq2_s_f32_len,   mul_mat_vec_id_iq2_s_f32_data,   "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S],   "mul_mat_vec_id_iq3_s_f32",   mul_mat_vec_id_iq3_s_f32_len,   mul_mat_vec_id_iq3_s_f32_data,   "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS],  "mul_mat_vec_id_iq4_xs_f32",  mul_mat_vec_id_iq4_xs_f32_len,  mul_mat_vec_id_iq4_xs_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL],  "mul_mat_vec_id_iq4_nl_f32",  mul_mat_vec_id_iq4_nl_f32_len,  mul_mat_vec_id_iq4_nl_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
 
     // dequant shaders
@@ -2020,6 +2034,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S],   "dequant_iq2_s",   dequant_iq2_s_len,   dequant_iq2_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S],   "dequant_iq3_s",   dequant_iq3_s_len,   dequant_iq3_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS],  "dequant_iq4_xs",  dequant_iq4_xs_len,  dequant_iq4_xs_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL],  "dequant_iq4_nl",  dequant_iq4_nl_len,  dequant_iq4_nl_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
     // get_rows
@@ -2035,6 +2050,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S],   "get_rows_iq2_s",   get_rows_iq2_s_len,   get_rows_iq2_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S],   "get_rows_iq3_s",   get_rows_iq3_s_len,   get_rows_iq3_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs",  get_rows_iq4_xs_len,  get_rows_iq4_xs_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl",  get_rows_iq4_nl_len,  get_rows_iq4_nl_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2049,6 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S],   "get_rows_iq2_s_f32",   get_rows_iq2_s_f32_len,   get_rows_iq2_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S],   "get_rows_iq3_s_f32",   get_rows_iq3_s_f32_len,   get_rows_iq3_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs_f32",  get_rows_iq4_xs_f32_len,  get_rows_iq4_xs_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl_f32",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
@@ -2995,6 +3012,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             break;
         default:
@@ -3048,6 +3066,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             break;
         default:
@@ -3084,6 +3103,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             break;
         default:
@@ -3132,6 +3152,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             break;
         default:
@@ -3163,6 +3184,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
         case GGML_TYPE_IQ2_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
             break;
         default:
@@ -8037,6 +8059,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_IQ2_S:
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_IQ4_NL:
                         break;
                     default:
@@ -8110,6 +8133,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 //case GGML_TYPE_IQ2_S:
                 //case GGML_TYPE_IQ3_XXS:
                 //case GGML_TYPE_IQ3_S:
+                //case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_IQ4_NL:
                     break;
                 default:
@@ -8132,6 +8156,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_IQ2_S:
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ3_S:
+                    case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_IQ4_NL:
                         return true;
                     default:
index aeae5400dfcb8752a35c9edaa2d27371d8f1d13a..9c9fe9626dbba7766214201ae8daf5743ec1cc82 100644 (file)
@@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
 #endif
 
 void main() {
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
     if (gl_LocalInvocationIndex.x != 0) {
         return;
index d4b068e61866a451c7861b7b7be1d3893c4e1c53..660811086d613db36c492c79c8eb275cad42e8f6 100644 (file)
@@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
 #endif
 
 void main() {
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
     if (gl_LocalInvocationIndex.x != 0) {
         return;
index ee68775317b8339cabf538a119c0fef693bd4ba7..ecfdbfaa88cec42c32677cb9aa07e610e3a78f9b 100644 (file)
@@ -304,6 +304,42 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_IQ4_XS)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    const uint ib32 = iqs / 32;
+    const uint iq = 16 * ib32 + (iqs % 16);
+
+    const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+    const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
+    const uint qshift = (iqs & 16) >> 2;
+    u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
+    qs = (qs >> qshift) & uint8_t(0xF);
+
+    const float dl = float(int(sl | (sh << 4)) - 32);
+    return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const uint ib32 = iqs / 32;
+    const uint iq = 16 * ib32 + (iqs % 16);
+
+    const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+    const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
+    const uint qshift = (iqs & 16) >> 2;
+    u8vec4 qs = u8vec4(
+        data_a[a_offset + ib].qs[iq + 0],
+        data_a[a_offset + ib].qs[iq + 1],
+        data_a[a_offset + ib].qs[iq + 2],
+        data_a[a_offset + ib].qs[iq + 3]
+    );
+    qs = (qs >> qshift) & uint8_t(0xF);
+
+    const float dl = float(int(sl | (sh << 4)) - 32);
+    return dl * vec4(
+        kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
+        kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
+}
+#endif
+
 #if defined(DATA_A_IQ4_NL)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -321,7 +357,7 @@ vec2 get_dm(uint ib, uint a_offset) {
 }
 #endif
 
-#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(float(data_a[a_offset + ib].d), 0);
 }
index 974efd3f9a693a59cd126f9eb440d00588c5dc18..78c3bddf227440678f91fd7e18ee42ab3dd63a4d 100644 (file)
@@ -454,6 +454,27 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords
 }
 #endif
 
+#if defined(DATA_A_IQ4_XS)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
+   block_iq4_xs block;
+};
+
+float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const float16_t d = bl.block.d;
+    const uint idx = coordInBlock[1];
+
+    const uint ib32 = (idx & 0xE0) >> 5; // 0..7
+
+    const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+    const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
+    const uint qshift = (idx & 16) >> 2;
+    const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
+
+    float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
+    return ret;
+}
+#endif
 
 #if defined(DATA_A_IQ4_NL)
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
@@ -504,6 +525,8 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 #define dequantFuncA dequantFuncIQ3_XXS
 #elif defined(DATA_A_IQ3_S)
 #define dequantFuncA dequantFuncIQ3_S
+#elif defined(DATA_A_IQ4_XS)
+#define dequantFuncA dequantFuncIQ4_XS
 #elif defined(DATA_A_IQ4_NL)
 #define dequantFuncA dequantFuncIQ4_NL
 #endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp
new file mode 100644 (file)
index 0000000..f930852
--- /dev/null
@@ -0,0 +1,34 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+    // Each thread handles 1 subblock (1 scale and 32 quantized values)
+    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
+
+    init_iq_shmem(gl_WorkGroupSize);
+
+    if (ib >= p.nel / 256) {
+        return;
+    }
+
+    const uint ib32 = gl_LocalInvocationID.x % 8;
+
+    const float d = float(data_a[ib].d);
+    // Scales are 6 bits
+    const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
+                     | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
+    const float dl = d * (int(scale) - 32);
+
+    const uint b_idx = 256 * ib + 32 * ib32;
+    const uint q_idx = 16 * ib32;
+    [[unroll]] for (uint l = 0; l < 16; ++l) {
+        data_b[b_idx + l +  0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
+        data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >>  4]);
+    }
+}
index 043a53023889c4f8da84cb98461b5164c2783101..ba88ce79a21ae0fad6ebd3ce55a6735eebdcc9f6 100644 (file)
@@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
 #endif
 
 void main() {
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
index 09dc43d8dc3b86d92237fb7cdfc91357a47e8eaa..c16a2a9f605c563519a896bea97a56b0d99e7d9b 100644 (file)
@@ -12,7 +12,7 @@ void main() {
     const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
     const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
 
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
index 48156e7bab6cdd3a2e0c93f6f4c7fdac5636573c..d7e99727db184286273f60bedeae5ab75370effa 100644 (file)
@@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 void main() {
     const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
 
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
index d0559aac8ec9228947c999e991612ce9a171aaff..33b2234e71df0d7f102a4ef4fadce44c30e98dcd 100644 (file)
@@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif
 
 void main() {
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
@@ -547,6 +547,25 @@ void main() {
             const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
             const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
 
+            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
+            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_IQ4_XS)
+            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+            const uint ib = idx / 128;                  // 2 values per idx
+            const uint ib32 = (idx % 128) / 16;         // 0..7
+            const uint iq = 16 * ib32 + 2 * (idx % 8);
+
+            const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
+            const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
+            const uint qshift = (idx & 8) >> 1;
+            u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
+            qs = (qs >> qshift) & uint8_t(0xF);
+
+            const float d = float(data_a[ib].d);
+            const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
+
             buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
             buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_IQ4_NL)
index 27c5d68b3d94ef696cdd17d0d91541c4ec6df0fe..7e29bbfec7b33e541f8a54a16d8e83b531978f3b 100644 (file)
@@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
 #endif
 
 void main() {
-#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
index 9e56a35300bafc8eb402973d316445fb9a9f013a..db643a54c8e78d86dbe100503e3c140dfd7981a6 100644 (file)
@@ -1026,6 +1026,23 @@ void init_iq_shmem(uvec3 wgsize)
 #define A_TYPE_PACKED16 block_iq3_s_packed16
 #endif
 
+#define QUANT_K_IQ4_XS 256
+#define QUANT_R_IQ4_XS 1
+
+struct block_iq4_xs
+{
+    float16_t d;
+    uint16_t scales_h;
+    uint8_t scales_l[QUANT_K_IQ4_XS/64];
+    uint8_t qs[QUANT_K_IQ4_XS/2];
+};
+
+#if defined(DATA_A_IQ4_XS)
+#define QUANT_K QUANT_K_IQ4_XS
+#define QUANT_R QUANT_R_IQ4_XS
+#define A_TYPE block_iq4_xs
+#endif
+
 #define QUANT_K_IQ4_NL 32
 #define QUANT_R_IQ4_NL 2
 
@@ -1042,7 +1059,13 @@ struct block_iq4_nl_packed16
 };
 
 #if defined(DATA_A_IQ4_NL)
+#define QUANT_K QUANT_K_IQ4_NL
+#define QUANT_R QUANT_R_IQ4_NL
+#define A_TYPE block_iq4_nl
+#define A_TYPE_PACKED16 block_iq4_nl_packed16
+#endif
 
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
 const int8_t kvalues_iq4nl_const[16] = {
     int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
     int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
@@ -1058,11 +1081,6 @@ void init_iq_shmem(uvec3 wgsize)
     }
     barrier();
 }
-
-#define QUANT_K QUANT_K_IQ4_NL
-#define QUANT_R QUANT_R_IQ4_NL
-#define A_TYPE block_iq4_nl
-#define A_TYPE_PACKED16 block_iq4_nl_packed16
 #endif
 
 #endif // !defined(GGML_TYPES_COMP)
index 93ddbfadc5f9abd15ff26e9449ca077c94457e16..77e7e1148b49d28d118d775d17cfc70bef37e5b6 100644 (file)
@@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
     "iq2_s",
     "iq3_xxs",
     "iq3_s",
+    "iq4_xs",
     "iq4_nl"
 };