ggml: allow casting between f32 and i32 (llama/15783)
author    Xuan-Son Nguyen <redacted>
Mon, 8 Sep 2025 10:33:01 +0000 (17:33 +0700)
committer Georgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* ggml: allow casting between f32 and i32

* fix cuda

* add vulkan

* fix CPU non-cont

* add non-cont test case

* add note

* extend test number range

* correct note

* add cont version for vulkan
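
Usage sketch (not part of the commit; a minimal illustration of the new cast path through the existing ggml API, error handling omitted):

    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    // fractional parts are discarded, e.g. 3.7f -> 3
    struct ggml_tensor * b = ggml_cast(ctx, a, GGML_TYPE_I32);
    // ... and back
    struct ggml_tensor * c = ggml_cast(ctx, b, GGML_TYPE_F32);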

12 files changed:
include/ggml-cpu.h
include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cuda/convert.cuh
src/ggml-cuda/cpy.cu
src/ggml-cuda/ggml-cuda.cu
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 1a78935aa05cf32a1025c7989fbe12089e860f69..9edd485136972d88a3192b1369249d1bc5a2db64 100644 (file)
@@ -134,6 +134,7 @@ extern "C" {
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);
+    GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *,     int32_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
index c01b98ac78f5a823665fe4fad824a3effa94acac..058f4267f754402d1028f0191fb717e1a8469b82 100644 (file)
@@ -1404,6 +1404,7 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // note: casting from f32 to i32 will discard the fractional part
     GGML_API struct ggml_tensor * ggml_cast(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
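
The note above follows standard C float-to-int conversion, which truncates toward zero:

    int32_t y0 = (int32_t)  3.7f;  // ->  3
    int32_t y1 = (int32_t) -3.7f;  // -> -3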
index 09772e806188c0fc48db059ad0f52bfafd9d9838..c1312908495385fc7fb7c37e12420c9982717b94 100644 (file)
@@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_I32] = {
+        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
+    },
 };
 
 const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@@ -2696,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan(
                         if (ggml_is_quantized(node->type) ||
                             // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
                             (node->src[0]->type == GGML_TYPE_F16  && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
-                            (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
+                            (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
+                            // conversion between F32 and I32
+                            (node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
+                            (node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
                             cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                         }
                     } break;
@@ -3258,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
     }
 }
 
+void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
+    int64_t i = 0;
+    for (; i < n; ++i) {
+        y[i] = x[i];
+    }
+}
+
 void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
     int64_t i = 0;
 #if defined(__AVX2__)
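
Note that, as in standard C, converting a float whose truncated value does not fit in int32_t is undefined behavior, so the new helper assumes in-range inputs. A saturating variant (not part of this commit, purely illustrative; assumes <stdint.h>) might look like:

    static void fp32_to_i32_sat(const float * x, int32_t * y, int64_t n) {
        for (int64_t i = 0; i < n; ++i) {
            const float v = x[i];
            if (v != v)                   { y[i] = 0;           } // NaN
            else if (v >=  2147483648.0f) { y[i] = INT32_MAX;   } // >= 2^31
            else if (v <  -2147483648.0f) { y[i] = INT32_MIN;   } // < -2^31
            else                          { y[i] = (int32_t) v; } // truncate toward zero
        }
    }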
index 0bb767e01aed3ba66aec91fcc691a24290f3d08a..9adf910769f8ecafc1a1c270ce9d3d22181beeba 100644 (file)
@@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
                         id += ne00 * (ne01 - ir1);
                     }
                 }
+            } else if (dst->type == GGML_TYPE_I32) {
+                size_t id = 0;
+                int32_t * dst_ptr = (int32_t *) dst->data;
+
+                for (int i03 = 0; i03 < ne03; i03++) {
+                    for (int i02 = 0; i02 < ne02; i02++) {
+                        id += ne00 * ir0;
+                        for (int i01 = ir0; i01 < ir1; i01++) {
+                            for (int i00 = 0; i00 < ne00; i00++) {
+                                const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+                                dst_ptr[id] = *src0_ptr;
+                                id++;
+                            }
+                        }
+                        id += ne00 * (ne01 - ir1);
+                    }
+                }
             } else {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
@@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         }
+    } else if (dst->type == GGML_TYPE_I32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(int32_t *) dst_ptr = *(const float *) src0_ptr;
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    } else {
+        GGML_ABORT("fatal error"); // TODO: implement
+    }
+}
+
+static void ggml_compute_forward_dup_i32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    const int ith = params->ith; // thread index
+    const int nth = params->nth; // number of threads
+
+    // parallelize by rows
+    const int nr = ne01;
+    // number of rows per thread
+    const int dr = (nr + nth - 1) / nth;
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // dst counters
+
+    int64_t i10 = 0;
+    int64_t i11 = 0;
+    int64_t i12 = 0;
+    int64_t i13 = 0;
+
+    // TODO: not optimal, but works
+    if (dst->type == GGML_TYPE_F32) {
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                i10 += ne00 * ir0;
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+                for (int64_t i01 = ir0; i01 < ir1; i01++) {
+                    for (int64_t i00 = 0; i00 < ne00; i00++) {
+                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
+
+                        *(float *) dst_ptr = *(const int32_t *) src0_ptr;
+
+                        if (++i10 == ne0) {
+                            i10 = 0;
+                            if (++i11 == ne1) {
+                                i11 = 0;
+                                if (++i12 == ne2) {
+                                    i12 = 0;
+                                    if (++i13 == ne3) {
+                                        i13 = 0;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                i10 += ne00 * (ne01 - ir1);
+                while (i10 >= ne0) {
+                    i10 -= ne0;
+                    if (++i11 == ne1) {
+                        i11 = 0;
+                        if (++i12 == ne2) {
+                            i12 = 0;
+                            if (++i13 == ne3) {
+                                i13 = 0;
+                            }
+                        }
+                    }
+                }
+            }
+        }
     } else {
         GGML_ABORT("fatal error"); // TODO: implement
     }
@@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
             {
                 ggml_compute_forward_dup_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_dup_i32(params, dst);
+            } break;
         default:
             {
                 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
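
The non-contiguous paths above step the destination coordinates i10..i13 like a mixed-radix odometer over the dst shape (ne0, ne1, ne2, ne3), so that a linear walk over src rows lands at the correct strided position in dst. The increment step, extracted as a standalone sketch (illustrative only):

    // advance a 4D dst coordinate by one element; i0 is the fastest-moving index
    static inline void dst_coord_next(int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3,
                                      int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
        if (++*i0 == ne0) { *i0 = 0;
            if (++*i1 == ne1) { *i1 = 0;
                if (++*i2 == ne2) { *i2 = 0;
                    if (++*i3 == ne3) { *i3 = 0; }
                }
            }
        }
    }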
index c62e8a1b1040a771435cfd71c2fd29e5945b7b36..ef9e129950c98f010f113ab4e45427cb55ca343e 100644 (file)
@@ -38,6 +38,8 @@ template<typename dst_t, typename src_t>
         return __float2bfloat16(float(x));
     } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
         return __bfloat162float(x);
+    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
+        return int32_t(x);
     } else {
         return float(x);
     }
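
Note: only the destination type needs a special case here. For an i32 source the generic `float(x)` fallback at the end of the chain already performs the int-to-float conversion, so no `src_t == int32_t` branch is required.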
index c40db08cedb05e27af6f50fd8df9c1faa64decad..8567c3d5a16b0d19cbe00c6f2dbbd634894ae034 100644 (file)
@@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
index 9f5dcd9341214b53708496f3bc427dba2d21b5c0..a88b9f75ef2982b3835d8346f6d70957b9d1b947 100644 (file)
@@ -3461,6 +3461,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
+                    return true;
+                }
+                if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                     return true;
                 }
index c1a0a2bef171e5fa169b4d4bbfcabaa259379815..578bdd6ecaa3a447df3c70ebbc0f3d85f1430df6 100644 (file)
@@ -583,6 +583,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
+    GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -1616,6 +1618,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                     cpy_f16_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                    cpy_bf16_f32,                    use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                   cpy_bf16_bf16,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32,                     cpy_f32_i32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32,                     cpy_i32_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                    cpy_f32_q8_0,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                    cpy_f32_q4_0,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                    cpy_f32_q4_1,                    true);
@@ -1945,6 +1949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                            case GGML_TYPE_Q5_0:
                            case GGML_TYPE_Q5_1:
                            case GGML_TYPE_IQ4_NL:
+                           case GGML_TYPE_I32:
                                 return true;
                            default:
                                 return false;
@@ -1977,6 +1982,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                             default:
                                 return false;
                         }
+                    case GGML_TYPE_I32:
+                        return op->type == GGML_TYPE_F32;
                     default:
                         return false;
                 };
@@ -5680,6 +5687,7 @@ static int ggml_metal_encode_node(
 
                             switch (dstt) {
                                 case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                                case GGML_TYPE_I32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
                                 case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
                                 case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
                                 case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
@@ -5691,6 +5699,13 @@ static int ggml_metal_encode_node(
                                 default: GGML_ABORT("not implemented");
                             };
                         } break;
+                    case GGML_TYPE_I32:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
                     case GGML_TYPE_F16:
                         {
                             switch (dstt) {
index 2d56c62674c8e049b491f90c914a9491bd54a9dc..4dc762bf1af4dc31c138c596d0259efbadf938c6 100644 (file)
@@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
 
 template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy<float,  float>;
 template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy<float,  half>;
+template [[host_name("kernel_cpy_f32_i32")]]   kernel kernel_cpy_t kernel_cpy<float,  int32_t>;
+template [[host_name("kernel_cpy_i32_f32")]]   kernel kernel_cpy_t kernel_cpy<int32_t, float>;
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy<float,  bfloat>;
 #endif
index 36951bceb8d6780ff3a7fc53ffcfc6ba4222c81e..e6245d0cfca4fae3dbeed79d08f2e4e145f07c5c 100644 (file)
@@ -506,8 +506,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
@@ -3226,12 +3226,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
@@ -5693,6 +5697,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cpy_f32_bf16;
         }
     }
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_i32;
+        } else {
+            return ctx->device->pipeline_cpy_f32_i32;
+        }
+    }
+    if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_i32_f32;
+        } else {
+            return ctx->device->pipeline_cpy_i32_f32;
+        }
+    }
     if (src->type == GGML_TYPE_F32) {
         switch (to) {
         case GGML_TYPE_Q4_0:
@@ -12224,6 +12242,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return true;
                 }
 
+                if (
+                    src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 ||
+                    src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32
+                ) {
+                    return true;
+                }
+
                 // We can handle copying from a type to the same type if it's
                 // contiguous (memcpy). We use f16 or f32 shaders to do the copy,
                 // so the type/block size must be a multiple of 4.
index 27394c170fc6895bc1bdd88c6306b471eb00fb23..b6570e020360838d9cb23e614fc1fbe322075119 100644 (file)
@@ -560,10 +560,14 @@ void process_shaders() {
     string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
     string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
     string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
+    string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
     string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
     string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
     string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
+    string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
+    string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
 
     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
         string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
index f4cd8f65cf8be81c51d60cdf9e643842c1479fab..e8e6e1000b2eb108dd2ecedffed954891c907fe4 100644 (file)
@@ -2457,6 +2457,13 @@ struct test_cpy : public test_case {
 
         return out;
     }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check if casting between f32 and i32 is consistent
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
 };
 
 // GGML_OP_CONT
@@ -6007,6 +6014,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
         }
     }
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
 
     test_cases.emplace_back(new test_cont());
     test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
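
Assuming the usual llama.cpp test harness conventions (the exact flags may differ), the new cases can be exercised in isolation with something like:

    ./bin/test-backend-ops test -o CPY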