]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : extend im2col f16 (ggml/1434)
authorDavid366AI <redacted>
Sun, 15 Mar 2026 19:50:56 +0000 (15:50 -0400)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 15:22:06 +0000 (17:22 +0200)
* examples/yolo: fix load_model memory leak

* fix/issue-1433 ggml_compute_forward_im2col_f16 assert error

* fix/issue-1433

ggml/src/ggml-cpu/ops.cpp

index 314cc1088a0625abb06d6d5310057c20459ad4d2..3f85e531daa6c2fec4d9460c222abf53b2cd71ac 100644 (file)
@@ -6205,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16(
     const ggml_tensor * src1 = dst->src[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F16);
 
     GGML_TENSOR_BINARY_OP_LOCALS;
@@ -6236,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16(
     int ofs1 = is_2D ? nb12 : nb11;
 
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
     // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
     {
@@ -6249,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16(
 
                         // micro kernel
                         ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
+                        const float * const src_data_f32 = src1->type == GGML_TYPE_F32
+                            ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1)
+                            : nullptr; // [IH, IW]
+                        const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16
+                            ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1)
+                            : nullptr; // [IH, IW]
 
                         for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
                             for (int64_t ikw = 0; ikw < KW; ikw++) {
@@ -6259,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16(
                                 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                                     dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
                                 } else {
-                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
+                                    if (src_data_f32 != nullptr) {
+                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]);
+                                    } else {
+                                        dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw];
+                                    }
                                 }
                             }
                         }