#include "solve_tri.cuh"
#define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
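+ // Builds one device pointer per (i2, i3) batch for A and X, since
+ // cublasStrsmBatched consumes arrays of per-matrix device pointers.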
+static __global__ void get_batch_pointers(const float * A,
+ float * X,
+ const float ** A_ptrs,
+ float ** X_ptrs,
+ int64_t ne02,
+ int64_t total_batches,
+ size_t s02,
+ size_t s03,
+ size_t s2,
+ size_t s3) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= total_batches) {
+ return;
+ }
+
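+ // decompose the flat batch index into (i3, i2) coordinates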
+ const int64_t i3 = idx / ne02;
+ const int64_t i2 = idx % ne02;
+
+ A_ptrs[idx] = A + i3 * s03 + i2 * s02;
+ X_ptrs[idx] = X + i3 * s3 + i2 * s2;
+}
+
+static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
+ const float * A,
+ const float * B,
+ float * X,
+ int n,
+ int k,
+ int64_t ne02,
+ int64_t ne03,
+ size_t s02,
+ size_t s03,
+ size_t s12,
+ size_t s13,
+ size_t s2,
+ size_t s3,
+ cudaStream_t stream) {
+ const float alpha = 1.0f;
+ const int64_t total_batches = ne02 * ne03;
+ if (total_batches == 0) {
+ return;
+ }
+
+ // cuBLAS trsm solves in place, so bulk-copy the right-hand side B into X first (contiguous tensors)
+ if (X != B) {
+ const int64_t total_elements_BX = n * k * total_batches;
+ CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+ }
+
+ const int id = ggml_cuda_get_device();
+
+ ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
+ ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
+
+ const float ** A_ptrs_dev = A_ptrs_alloc.get();
+ float ** X_ptrs_dev = X_ptrs_alloc.get();
+
+ get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
+ total_batches, s02, s03, s2, s3);
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+
+ // This really is necessary: with TF32 tensor-op math left enabled, the solve accumulates RMSE errors.
+ CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
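+ // ggml tensors are row-major while cuBLAS is column-major, so cuBLAS sees each
+ // matrix as its transpose. The row-major solve A * X = B (A lower triangular)
+ // therefore becomes X^T * A^T = B^T, i.e. a right-side solve against an
+ // upper-triangular matrix: m = k, n = n, lda = n and ldb = k for the transposed views.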
+ CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
+ CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
+
+ // restore the TF32 math mode that common.cuh sets as the default
+ CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
+
+ GGML_UNUSED_VARS(s12, s13);
+}
// ======================
// Fast kernel (n <= MAX_N_FAST, k <= MAX_K_FAST) - warp-based parallel reduction
float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
- const int half = WARP_SIZE;
+ const int half = WARP_SIZE;
const int nrows_low = (n < half) ? n : half;
#pragma unroll
for (int row = half; row < n; ++row) {
- float sum = sA[row * n + lane] * x_low;
- const int j = half + lane;
+ float sum = sA[row * n + lane] * x_low;
+ const int j = half + lane;
if (j < row) {
sum += sA[row * n + j] * x_high;
}
for (int rr = 0; rr < 2; ++rr) {
const int row = rr * WARP_SIZE + lane;
if (row < n) {
- const float val = (row < half) ? x_low : x_high;
+ const float val = (row < half) ? x_low : x_high;
X_batch[row * k + col_idx] = val;
}
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix)
- const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns)
+ const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
+ const ggml_tensor * src1 = dst->src[1]; // B (n×k)
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
- const int64_t n = src0->ne[0];
- const int64_t k = src1->ne[0];
+ const int64_t n = src0->ne[0];
+ const int64_t k = src1->ne[0];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
- GGML_ASSERT(n <= 64);
- GGML_ASSERT(k <= 32);
-
- solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
- src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
- src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
- dst->nb[3] / sizeof(float), ctx.stream());
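+ // small systems stay on the dedicated warp-level kernel; anything larger
+ // falls back to batched cuBLAS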
+ if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+ solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+ src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+ src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+ dst->nb[3] / sizeof(float), ctx.stream());
+ } else {
+ solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+ ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+ src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+ dst->nb[3] / sizeof(float), ctx.stream());
+ }
}
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 30, 30, 7, 1 }, { 8, 30, 7, 1 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 42, 42, 5, 2 }, { 10, 42, 5, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 10, 64, 2, 2 }));
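+ // cases with n > MAX_N_FAST (64) or k > MAX_K_FAST (32) exercise the cuBLAS fallback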
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 2, 2 }, { 64, 64, 2, 2 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 79, 79, 5, 3 }, { 417, 79, 5, 3 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 80, 80, 2, 8 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 79, 80, 2, 8 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 2, 8 }, { 81, 80, 2, 8 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 80, 80, 8, 8 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 79, 80, 8, 8 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 80, 80, 8, 8 }, { 81, 80, 8, 8 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 84, 84, 4, 4 }, { 32, 84, 4, 4 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 95, 95, 8, 8 }, { 40, 95, 8, 8 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 100, 100, 4, 4 }, { 41, 100, 4, 4 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 31, 128, 4, 4 }));
- test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 300, 64, 4, 4 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 4 }, { 32, 128, 4, 4 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 3, 4 }, { 32, 128, 3, 4 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 32, 128, 4, 1 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 }));
for (bool v : {false, true}) {
for (bool circular : {false, true}) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
- test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
- test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 32, 64, 4, 4 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 2 }, { 32, 128, 4, 2 }));
// qwen3next with CHUNK_SIZE 64
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 8, 32 }, { 64, 64, 8, 32 }));
// qwen3next with CHUNK_SIZE 128
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 32 }, { 128, 128, 4, 32 }));
+ test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 256, 256, 4, 2 }, { 128, 256, 4, 2 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));