]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan/cuda: Fix im2col when KW!=KH (#14789)
authorJeff Bolz <redacted>
Mon, 21 Jul 2025 11:35:40 +0000 (06:35 -0500)
committerGitHub <redacted>
Mon, 21 Jul 2025 11:35:40 +0000 (13:35 +0200)
The tid is decomposed into "ow + ky*OW + kx*OW*KH". Change "ksize" to match.

ggml/src/ggml-cuda/im2col.cu
ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
tests/test-backend-ops.cpp

index 86a54e42bb7e64d4d2b574b9b51e7ba4b20c9ee5..5bb85b4807bcf5fdb2f8967728c74ab3fa24f116 100644 (file)
@@ -10,7 +10,7 @@ static  __global__ void im2col_kernel(
         return;
     }
 
-    const int64_t  ksize = OW * (KH > 1 ? KW : 1);
+    const int64_t  ksize = OW * KH;
     const int64_t  kx = i / ksize;
     const int64_t  kd = kx * ksize;
     const int64_t  ky = (i - kd) / OW;
index 17c7ccb90d001a44229ac59cecdf360cd2f51487..fdbcf7eba0fa588626499213b6947211aa9af6d2 100644 (file)
@@ -40,12 +40,10 @@ void main() {
     const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
     const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
     const int oh_s1 = int(oh) * p.s1;
-    const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+    const uint ksize = p.OW * p.KH;
 
     const uint base_linear_idx = gidx * NUM_ITER;
 
-    const uint max_ky = ksize / p.OW;
-
     uint current_kx = base_linear_idx / ksize;
     const uint rem = base_linear_idx - (current_kx * ksize);
     uint current_ky = rem / p.OW;
@@ -76,7 +74,7 @@ void main() {
 
         if (++current_ix == p.OW) {
             current_ix = 0;
-            if (++current_ky == max_ky) {
+            if (++current_ky == p.KH) {
                 current_ky = 0;
                 current_kx++;
             }
index 731b4980af9473072fc10f0dcd4244708fbb8a73..a6d00542dd21efd19c1cd305e4313eb8c540b1bc 100644 (file)
@@ -5093,6 +5093,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
 
 // Conv_2D test cases
 #ifdef DETAILED_TESTS