size_t src1_stride_size = sizeof(cuda_t);
- dim3 block_dims(ne13, ne12);
- k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+ const int threads_x = 16;
+ const int threads_y = 16;
+ dim3 block_dims(threads_x, threads_y);
+
+ dim3 grid_dims(
+ (ne13 + threads_x - 1) / threads_x,
+ (ne12 + threads_y - 1) / threads_y
+ );
+ k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
src0_ptr, src1_ptr, dst_t,
ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
+
+ // test cases with large batch size
+ test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1}));
}
}
for (ggml_type type_a : other_types) {