From: Johannes Gäßler Date: Mon, 24 Jun 2024 20:15:33 +0000 (+0200) Subject: CUDA: fix MMQ writeback for int8 tensor cores (#8100) X-Git-Tag: upstream/0.0.4488~1270 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=3b099bcd9cbf2434f90cbe40eba6fa2189ed1d02;p=pkg%2Fggml%2Fsources%2Fllama.cpp CUDA: fix MMQ writeback for int8 tensor cores (#8100) --- diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 1fc948be..31fcbf13 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -2054,15 +2054,13 @@ static __device__ __forceinline__ void mmq_write_back_mma( static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); #endif // INT8_MMA_AVAILABLE - dst += (threadIdx.y % ntx) * mma_C::J*stride; - #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - const int j = j0 + mma_C::get_j(l); + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); if (j > j_max) { continue;