xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
- if constexpr (I == 16 && J == 4) {
- int64_t * xi = (int64_t *) t.x;
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
- xi[0] = xs[0];
- }else if constexpr (I == 16 && J == 8) {
- int64_t * xi = (int64_t *) t.x;
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
- xi[0] = xs[0];
+ if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+
+ } else if constexpr (std::is_same_v<T, int>) {
+ if constexpr (I == 16 && J == 4) {
+ int64_t * xi = (int64_t *) t.x;
+ const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
+ xi[0] = xs[0];
- const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
- xi[1] = xs1[0];
- }else{
+ }else if constexpr (I == 16 && J == 8) {
+ int64_t * xi = (int64_t *) t.x;
+ const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
+ xi[0] = xs[0];
+
+ const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
+ xi[1] = xs1[0];
+
+ }else{
+ NO_DEVICE_CODE;
+ }
+ } else {
NO_DEVICE_CODE;
}
#else
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const size_t nbs_ids = mmq_x*sizeof(int);
- const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
}