]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add support for sqrt on CUDA (#7953)
authorCalvin Laurenson <redacted>
Sun, 16 Jun 2024 22:23:04 +0000 (15:23 -0700)
committerGitHub <redacted>
Sun, 16 Jun 2024 22:23:04 +0000 (00:23 +0200)
* cuda sqrt support

* enable cuda in pca

* fix comments in pca

* add test

* add sqrt to ggml_backend_cuda_supports_op

* fix test

* new line

* Use F32 sqrtf instead of F64 sqrt

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
examples/cvector-generator/pca.hpp
ggml-cuda.cu
ggml-cuda/unary.cu
ggml-cuda/unary.cuh
tests/test-backend-ops.cpp

index 8b95cec374c2306bbf3bc8f7ecc2c435feefbec3..36eadaac26a1267f64382f132cd15415d6ffc36f 100644 (file)
@@ -64,15 +64,15 @@ struct pca_model {
     struct ggml_tensor * dev_eigenvector;
 
     pca_model(struct ggml_tensor * t_input) {
-// TODO: enable GPU support when support for GGML_OP_SQRT is added
-// #ifdef GGML_USE_CUDA
-//         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-//         backend = ggml_backend_cuda_init(0); // init device 0
-//         if (!backend) {
-//             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-//         }
-// #endif
+#ifdef GGML_USE_CUDA
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        backend = ggml_backend_cuda_init(0); // init device 0
+        if (!backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+#endif
 
+// TODO: enable Metal support when support for GGML_OP_SQRT is added
 // #ifdef GGML_USE_METAL
 //         fprintf(stderr, "%s: using Metal backend\n", __func__);
 //         backend = ggml_backend_metal_init();
index 593fa4cdaa51431b4c2773a7686c613ffd5dbdaf..b8298ab205e6009b3e4313ffcb2ed2c2a6efd132 100644 (file)
@@ -2267,6 +2267,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQR:
             ggml_cuda_op_sqr(ctx, dst);
             break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
@@ -2830,6 +2833,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
index a5ff96320f23f4537053b3d80c6d93d79cb2b5ab..f9e208011e2a8f2b1387d2c2173e8056bf0ad068 100644 (file)
@@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sqrtf(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
index a1d07c04fcd4350a321690dbb6f824c8fdf7a2df..4cfb0479e7169a83a8b433b21b9b381aeb9158d9 100644 (file)
@@ -8,6 +8,7 @@
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 2b48e623e34764611bad5bde8c58d59b04ef8821..7c504e937a8518aedcd49783e0d7e627e9109808 100644 (file)
@@ -1063,6 +1063,33 @@ struct test_sqr : public test_case {
     }
 };
 
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqrt(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sqrt(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // fill with positive values
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 100.0f);
+        }
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;
@@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_sqr());
+    test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_clamp());
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));