layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
-shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ csel ^= 1;
- barrier();
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
- sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
- sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+ sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
+ sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
- sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
- sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+ sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
+ sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
- sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
- fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
- fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
- fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
- fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
- fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
- fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
- fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
- sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
- fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
- fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
- fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
- fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
- fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
- fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
- fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
+ sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
+ fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
+ fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
+ fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
+ fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
+ fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
+ fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
+ fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
+ sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
+ fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
+ fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
+ fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
+ fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
+ fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
+ fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
+ fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
+shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
- barrier();
if (i < num_blocks_per_row)
- sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+ sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
if (i >= num_blocks_per_row)
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
if (all_threads) {
- barrier();
- sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+ sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
}
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
- sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
- fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
- fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
- fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
- fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
- fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
- fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
- fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
+ sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
+ fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
+ fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
+ fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
+ fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
+ fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
+ fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
+ fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+ csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
- barrier();
if (i < num_blocks_per_row)
- sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+ sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
if (i >= num_blocks_per_row)
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
if (all_threads) {
- barrier();
- sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+ sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
}
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
}
- temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
+ temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
}
}
}