Inside the container, execute the following commands:
```bash
-apt update -y && apt install -y bc cmake git python3.10-venv time unzip wget
+apt update -y && apt install -y bc cmake ccache git python3.10-venv time unzip wget
git config --global --add safe.directory /ws
GG_BUILD_MUSA=1 bash ./ci/run.sh /ci-results /ci-cache
```
if [ ! -z ${GG_BUILD_MUSA} ]; then
# Use qy1 by default (MTT S80)
MUSA_ARCH=${MUSA_ARCH:-21}
- CMAKE_EXTRA="-DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
+ CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_MUSA=ON -DMUSA_ARCHITECTURES=${MUSA_ARCH}"
fi
## helpers
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
+#ifdef _MSC_VER
+#define GGML_EXTENSION
+#else // _MSC_VER
+#define GGML_EXTENSION __extension__
+#endif // _MSC_VER
+
#define QK4_0 32
typedef struct {
ggml_half d; // delta
#define QK4_1 32
typedef struct {
- union {
+ GGML_EXTENSION union {
struct {
ggml_half d; // delta
ggml_half m; // min
#define QK5_1 32
typedef struct {
- union {
+ GGML_EXTENSION union {
struct {
ggml_half d; // delta
ggml_half m; // min
#define QK8_1 32
typedef struct {
- union {
+ GGML_EXTENSION union {
struct {
ggml_half d; // delta
ggml_half s; // d * sum(qs[i])
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
- union {
+ GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
- union {
+ GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
// weight is represented as x = a * q + b
// Effectively 5.5 bits per weight
typedef struct {
- union {
+ GGML_EXTENSION union {
struct {
ggml_half d; // super-block scale for quantized scales
ggml_half dmin; // super-block scale for quantized mins
__trap();
GGML_UNUSED(no_device_code); // suppress unused function warning
+
+#if defined(GGML_USE_MUSA)
+ __builtin_unreachable();
+#endif // defined(GGML_USE_MUSA)
}
#ifdef __CUDA_ARCH__
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
- if (blockIdx.y < ne01) { // src0
+ if (blockIdx.y < (unsigned)ne01) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
- if (blockIdx.z < ne02) { // src0
+ if (blockIdx.z < (unsigned)ne02) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
}
}
dst[global_index] = accumulator;
+ GGML_UNUSED(p0); GGML_UNUSED(d0); GGML_UNUSED(src0_ne3);
+ GGML_UNUSED(src1_ne3); GGML_UNUSED(dst_ne3);
+ GGML_UNUSED(src1_ne1); GGML_UNUSED(dst_ne1);
+ GGML_UNUSED(src1_ne2); GGML_UNUSED(dst_ne2);
}
static void conv_transpose_1d_f32_f32_cuda(
const int p0 = 0;//opts[3];
const int d0 = 1;//opts[4];
- const int64_t kernel_size = ggml_nelements(src0);
- const int64_t input_size = ggml_nelements(src1);
const int64_t output_size = ggml_nelements(dst);
conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
return;
}
- const src_t * x = (src_t *) vx;
+ const src_t * x = (const src_t *) vx;
y[i] = x[i];
}
float vals[sizeof(int)] = {0.0f};
#pragma unroll
- for (int l = 0; l < sizeof(int); ++l) {
+ for (int l = 0; l < int(sizeof(int)); ++l) {
vals[l] = scale * x[4*threadIdx.x + l];
}
float amax = fabsf(vals[0]);
float sum = vals[0];
#pragma unroll
- for (int l = 1; l < sizeof(int); ++l) {
+ for (int l = 1; l < int(sizeof(int)); ++l) {
amax = fmaxf(amax, fabsf(vals[l]));
sum += vals[l];
}
if (d != 0.0f) {
#pragma unroll
- for (int l = 0; l < sizeof(int); ++l) {
+ for (int l = 0; l < int(sizeof(int)); ++l) {
q8[l] = roundf(vals[l] / d);
}
}
float VKQ_denominator = 0.0f;
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
- const float KQ_max_scale = expf(diff);
+ float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
}
+[[noreturn]]
static void on_no_fattn_vec_case(const int D) {
if (D == 64) {
fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
#endif // CP_ASYNC_AVAILABLE
#else
+ GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+ GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+ GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
+ GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+ GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+ GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
+ GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
+ GGML_UNUSED(kb0);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__syncthreads();
}
#else
+ GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
+ GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
+ GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
+ GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
+ GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+ GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
-
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
-
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
-
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
// Kernels with ncols == 128 are only 4% faster due to register pressure.
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
}
}
#else
- NO_DEVICE_CODE;
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+ GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+ GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+ GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+ GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
}
}
#else
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+ GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+ GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+ GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+ GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
- NO_DEVICE_CODE;
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+ GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
+ GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+ GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
+ GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+ GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
+ GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
+ GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
+ GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+ GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
+ GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
+ GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
}
#else
- NO_DEVICE_CODE;
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
+ GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
+ GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+ GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+ GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+ GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
+ GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+ GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
+ GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+ NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
+ GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
return ret;
: "l"(xs));
#else
load_generic(xs0, stride);
+ GGML_UNUSED(t);
#endif // NEW_MMA_AVAILABLE
}
}
}
#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
}
#pragma unroll
- for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
- if (k01 < WARP_SIZE/2) {
- constexpr int ns = 2;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
- &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
- } else {
- constexpr int ns = 1;
- sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
- &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
- &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
- &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
- }
+ constexpr int ns = 2;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+ }
+ }
+
+ // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
+ // As a workaround 2 separate loops are used instead.
+#pragma unroll
+ for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ constexpr int ns = 1;
+ sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
}
}
}
#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
const float d = bxi->d;
#pragma unroll
- for (int l = 0; l < sizeof(int); ++l) {
+ for (int l = 0; l < int(sizeof(int)); ++l) {
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
}
#else
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
- for (int l = 0; l < sizeof(int); ++l) {
+ for (int l = 0; l < int(sizeof(int)); ++l) {
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
}
}
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
- for (int l = 0; l < sizeof(int); ++l) {
+ for (int l = 0; l < int(sizeof(int)); ++l) {
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
}
}
}
}
#else
- GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+ GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
} else {
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
}
+
+ GGML_UNUSED(ne00); GGML_UNUSED(ne10);
}
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
- if (it != blockIdx.x || jt != blockIdx.y) {
+ if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
continue;
}
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
- const int nsm = ggml_cuda_info().devices[id].nsm;
const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo;
__syncthreads();
}
- float sumf;
+ float sumf = 0.0f;
if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// partial sum for each thread
- float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+ float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = (const block_q8_1 *) vy;
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
- if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
}
}
+
+ GGML_UNUSED(nrows_x);
}
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
- if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+ if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
int offset_src =
nidx +
blockIdx.y * ne00 +
int i02 = i12 / sf2;
int i03 = i13 / sf3;
- dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+ dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
}
static void upscale_f32_cuda(const float * x, float * dst,