]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix MMQ writeback for int8 tensor cores (#8100)
authorJohannes Gäßler <redacted>
Mon, 24 Jun 2024 20:15:33 +0000 (22:15 +0200)
committerGitHub <redacted>
Mon, 24 Jun 2024 20:15:33 +0000 (22:15 +0200)
ggml-cuda/mmq.cuh

index 1fc948be5bbe838878a37b5fc767b1411a06e434..31fcbf1397b6bb15475af811299efc8064006dd3 100644 (file)
@@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma(
     static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
 #endif // INT8_MMA_AVAILABLE
 
-    dst += (threadIdx.y % ntx) * mma_C::J*stride;
-
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
             for (int l = 0; l < mma_C::ne; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
 
                 if (j > j_max) {
                     continue;