#include "types.comp"
#ifndef LOAD_VEC_A
-#define LOAD_VEC_A 2
+#define LOAD_VEC_A 1
#endif
#ifndef LOAD_VEC_B
-#define LOAD_VEC_B 2
+#define LOAD_VEC_B 1
+#endif
+
+// Load 2 values at once without affecting index calculations through LOAD_VEC
+#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
+#define LOAD_VEC_BATCH_A 2
+#else
+#define LOAD_VEC_BATCH_A 1
+#endif
+#if !defined(ALIGNED)
+#define LOAD_VEC_BATCH_B 2
+#else
+#define LOAD_VEC_BATCH_B 1
#endif
#if !defined(TO_FLOAT_TYPE)
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
- const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
- const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
- const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
- const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
+ const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
+ const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
+ const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
+ const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
- const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
- const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
+ const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
+ const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
-#else // LOAD_VEC_A == 2
- const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
+#else // LOAD_VEC_BATCH_A == 2
+ const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
-#else // LOAD_VEC_A == 2
- const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
+#else // LOAD_VEC_BATCH_A == 2
+ const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
-#else // LOAD_VEC_B == 2
- const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
+#else // LOAD_VEC_BATCH_B == 2
+ const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
-#else // LOAD_VEC_B == 2
+#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
- const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
- const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);