]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
opencl: fix boundary handling for mul_mm (#16875)
authorlhez <redacted>
Thu, 30 Oct 2025 23:00:20 +0000 (16:00 -0700)
committerGitHub <redacted>
Thu, 30 Oct 2025 23:00:20 +0000 (16:00 -0700)
ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl
ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl
ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl

index 1a1bfe144f6109c4bcf62f6cf56d436786565f65..6982f8f514dd3cbab871c3fe5cce89cf59c371a0 100644 (file)
@@ -79,8 +79,8 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            if (loadc_a + l < ne01) {
-            const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
+            if (ir*BM + loadc_a + l < ne01) {
+                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
                 buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
                 buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
                 buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f16_f32_l4_lm(
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            if (loadc_b + l < ne11) {
+            if (ic*BN + loadc_b + l < ne11) {
                 const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
                 buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
                 buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
index 39a5d4868ffaaa2e33bea15e26209c3dc0ee9435..d7d5ba647e7084dd2bccf3c61a15dd6ed2306ad2 100644 (file)
@@ -79,7 +79,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            if (loadc_a + l < ne01) {
+            if (ir*BM + loadc_a + l < ne01) {
                 const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
                 buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
                 buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
@@ -94,7 +94,7 @@ kernel void kernel_mul_mm_f32_f32_l4_lm(
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            if (loadc_b + l < ne11) {
+            if (ic*BN + loadc_b + l < ne11) {
                 const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
                 buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
                 buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
index fd47e8a89dcef20f683757e1508696fbc4f6562a..147b66f6692a1842fd6be6157fb778a64d308c91 100644 (file)
@@ -78,7 +78,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
 
     for (int block = 0; block < ne00; block += BK) {
         for (int l = 0; l < BM; l += loadstride_a) {
-            if (loadc_a + l < ne01) {
+            if (ir*BM + loadc_a + l < ne01) {
                 int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
                 int ib  = idx / 8;
                 int iqs = idx % 8;
@@ -101,7 +101,7 @@ kernel void kernel_mul_mm_q8_0_f32_l4_lm(
         }
 
         for (int l = 0; l < BN; l += loadstride_b) {
-            if (loadc_b + l < ne11) {
+            if (ic*BN + loadc_b + l < ne11) {
                 int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
                 buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
                 buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;