]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml/ex: calculate accuracy in graph, adapt MNIST (#980)
authorJohannes Gäßler <redacted>
Thu, 3 Oct 2024 15:29:59 +0000 (17:29 +0200)
committerGitHub <redacted>
Thu, 3 Oct 2024 15:29:59 +0000 (17:29 +0200)
14 files changed:
examples/mnist/mnist-common.cpp
examples/mnist/mnist-common.h
examples/mnist/mnist-eval.cpp
include/ggml.h
src/ggml-cuda.cu
src/ggml-cuda/argmax.cu [new file with mode: 0644]
src/ggml-cuda/argmax.cuh [new file with mode: 0644]
src/ggml-cuda/common.cuh
src/ggml-cuda/count-equal.cu [new file with mode: 0644]
src/ggml-cuda/count-equal.cuh [new file with mode: 0644]
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml.c
tests/test-backend-ops.cpp

index 4b7374e05ffbae9c7759a7d2b8c620985739f6e1..c70dfd7fdc0e57d1660ed2d31ae63aa1ab45b70b 100644 (file)
@@ -480,24 +480,47 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int
     GGML_ASSERT(model.loss->ne[1] == 1);
     GGML_ASSERT(model.loss->ne[2] == 1);
     GGML_ASSERT(model.loss->ne[3] == 1);
+
+    model.pred = ggml_argmax(model.ctx_compute, model.logits);
+    ggml_set_name(model.pred, "predictions");
+    ggml_set_output(model.pred);
+    GGML_ASSERT(model.pred->type == GGML_TYPE_I32);
+    GGML_ASSERT(model.pred->ne[0] == model.nbatch_physical);
+    GGML_ASSERT(model.pred->ne[1] == 1);
+    GGML_ASSERT(model.pred->ne[2] == 1);
+    GGML_ASSERT(model.pred->ne[3] == 1);
+
+    model.acc_count = ggml_count_equal(model.ctx_compute, model.pred, ggml_argmax(model.ctx_compute, model.labels));
+    ggml_set_name(model.acc_count, "accuracy_count");
+    ggml_set_output(model.acc_count);
+    GGML_ASSERT(model.acc_count->type == GGML_TYPE_I64);
+    GGML_ASSERT(model.acc_count->ne[0] == 1);
+    GGML_ASSERT(model.acc_count->ne[1] == 1);
+    GGML_ASSERT(model.acc_count->ne[2] == 1);
+    GGML_ASSERT(model.acc_count->ne[3] == 1);
 }
 
 mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex) {
     mnist_eval_result result;
 
     struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute);
+    // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed.
     ggml_build_forward_expand(gf, model.loss);
+    ggml_build_forward_expand(gf, model.pred);
+    ggml_build_forward_expand(gf, model.acc_count);
 
     model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend);
 
     {
         const int64_t t_start_us = ggml_time_us();
 
-        float loss;
-        std::vector<float> logits(model.nbatch_physical*MNIST_NCLASSES);
+        float                tmp_loss;
+        std::vector<int32_t> tmp_pred(model.nbatch_physical);
+        int64_t              tmp_acc_count;
 
-        GGML_ASSERT(sizeof(loss)  == ggml_nbytes(model.loss));
-        GGML_ASSERT(logits.size() == ggml_nelements(model.logits));
+        GGML_ASSERT(sizeof(tmp_loss)                    == ggml_nbytes(model.loss));
+        GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred));
+        GGML_ASSERT(sizeof(tmp_acc_count)               == ggml_nbytes(model.acc_count));
 
         GGML_ASSERT(nex % model.nbatch_physical == 0);
         for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) {
@@ -506,15 +529,14 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
 
             ggml_backend_graph_compute(model.backend, gf);
 
-            ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
-            ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
-
-            result.loss.push_back(loss);
+            ggml_backend_tensor_get(model.loss,      &tmp_loss,       0, ggml_nbytes(model.loss));
+            ggml_backend_tensor_get(model.pred,      tmp_pred.data(), 0, ggml_nbytes(model.pred));
+            ggml_backend_tensor_get(model.acc_count, &tmp_acc_count,  0, ggml_nbytes(model.acc_count));
 
-            for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
-                const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
-                result.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
-            }
+            result.loss.push_back(tmp_loss);
+            result.pred.insert(result.pred.end(), tmp_pred.begin(), tmp_pred.end());
+            result.ncorrect += tmp_acc_count;
+            result.ntotal   += model.nbatch_physical;
         }
 
         const int64_t t_total_us = ggml_time_us() - t_start_us;
@@ -530,13 +552,18 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co
 void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) {
     const int64_t t_start_us = ggml_time_us();
 
+    const bool accumulate = model.nbatch_physical != model.nbatch_logical;
+
     // gf == graph forward, forward pass only.
     struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
+    // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed.
     ggml_build_forward_expand(gf, model.loss);
+    ggml_build_forward_expand(gf, model.pred);
+    ggml_build_forward_expand(gf, model.acc_count);
 
     // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
     struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
-    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true);
+    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, accumulate);
 
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
@@ -551,10 +578,16 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
         fprintf(stderr, "%s: epoch %02d start...", __func__, epoch);
         const int64_t t_start_us = ggml_time_us();
 
-        float loss;
-        std::vector<float> logits(model.nbatch_physical*MNIST_NCLASSES);
         int iex0 = 0;
 
+        float                tmp_loss;
+        std::vector<int32_t> tmp_pred(model.nbatch_physical);
+        int64_t              tmp_acc_count;
+
+        GGML_ASSERT(sizeof(tmp_loss)                    == ggml_nbytes(model.loss));
+        GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred));
+        GGML_ASSERT(sizeof(tmp_acc_count)               == ggml_nbytes(model.acc_count));
+
         mnist_eval_result result_train;
         for (; iex0 < iex_split; iex0 += model.nbatch_physical) {
             ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT,   0, ggml_nbytes(model.images));
@@ -570,15 +603,14 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
                 ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer.
             }
 
-            ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
-            ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
-
-            result_train.loss.push_back(loss);
+            ggml_backend_tensor_get(model.loss,      &tmp_loss,       0, ggml_nbytes(model.loss));
+            ggml_backend_tensor_get(model.pred,      tmp_pred.data(), 0, ggml_nbytes(model.pred));
+            ggml_backend_tensor_get(model.acc_count, &tmp_acc_count,  0, ggml_nbytes(model.acc_count));
 
-            for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
-                const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
-                result_train.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
-            }
+            result_train.loss.push_back(tmp_loss);
+            result_train.pred.insert(result_train.pred.end(), tmp_pred.begin(), tmp_pred.end());
+            result_train.ncorrect += tmp_acc_count;
+            result_train.ntotal   += model.nbatch_physical;
         }
 
         mnist_eval_result result_val;
@@ -588,20 +620,19 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
 
             ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed.
 
-            ggml_backend_tensor_get(model.loss,   &loss,         0, ggml_nbytes(model.loss));
-            ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits));
-
-            result_val.loss.push_back(loss);
+            ggml_backend_tensor_get(model.loss,      &tmp_loss,       0, ggml_nbytes(model.loss));
+            ggml_backend_tensor_get(model.pred,      tmp_pred.data(), 0, ggml_nbytes(model.pred));
+            ggml_backend_tensor_get(model.acc_count, &tmp_acc_count,  0, ggml_nbytes(model.acc_count));
 
-            for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) {
-                const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES;
-                result_val.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb);
-            }
+            result_val.loss.push_back(tmp_loss);
+            result_val.pred.insert(result_val.pred.end(), tmp_pred.begin(), tmp_pred.end());
+            result_val.ncorrect += tmp_acc_count;
+            result_val.ntotal   += model.nbatch_physical;
         }
 
         {
             const double loss_mean = mnist_loss(result_train).first;
-            const double percent_correct = 100.0 * mnist_accuracy(result_train, labels + 0*MNIST_NCLASSES).first;
+            const double percent_correct = 100.0 * mnist_accuracy(result_train).first;
 
             const int64_t t_epoch_us = ggml_time_us() - t_start_us;
             const double t_epoch_s = 1e-6*t_epoch_us;
@@ -610,7 +641,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *
 
         if (iex_split < nex) {
             const std::pair<double, double> loss = mnist_loss(result_val);
-            const std::pair<double, double> acc  = mnist_accuracy(result_val, labels + iex_split*MNIST_NCLASSES);
+            const std::pair<double, double> acc  = mnist_accuracy(result_val);
 
             fprintf(stderr, ", val_loss=%.6lf+-%.6lf, val_acc=%.2f+-%.2f%%", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second);
         }
@@ -668,7 +699,7 @@ void mnist_model_save(mnist_model & model, const std::string & fname) {
 
 std::pair<double, double> mnist_loss(const mnist_eval_result & result) {
     const size_t nbatches = result.loss.size();
-    GGML_ASSERT(nbatches >= 1);
+    GGML_ASSERT(nbatches >= 2);
 
     double sum         = 0.0;
     double sum_squared = 0.0;
@@ -684,20 +715,12 @@ std::pair<double, double> mnist_loss(const mnist_eval_result & result) {
     return std::make_pair(mean, uncertainty);
 }
 
-std::pair<double, double> mnist_accuracy(const mnist_eval_result & result, const float * labels) {
-    const size_t nex = result.pred.size();
-    GGML_ASSERT(nex >= 1);
-
-    size_t ncorrect = 0;
-    for (size_t iex = 0; iex < nex; ++iex) {
-        const float * labels_iex = labels + iex*MNIST_NCLASSES;
-        const int32_t label = std::max_element(labels_iex, labels_iex + MNIST_NCLASSES) - labels_iex;
-
-        ncorrect += result.pred[iex] == label;
-    }
+std::pair<double, double> mnist_accuracy(const mnist_eval_result & result) {
+    GGML_ASSERT(result.ntotal >= result.ncorrect);
+    GGML_ASSERT(result.ntotal >= 2);
 
-    const double fraction_correct = ((double) ncorrect) / ((double) nex);
-    const double uncertainty      = sqrt(fraction_correct * (1.0 - fraction_correct) / (nex - 1));
+    const double fraction_correct = ((double) result.ncorrect) / ((double) result.ntotal);
+    const double uncertainty      = sqrt(fraction_correct * (1.0 - fraction_correct) / (result.ncorrect - 1));
 
     return std::make_pair(fraction_correct, uncertainty);
 }
index 72638a4094c47519bc7878a244c98c479f22e9db..a6239a426bbc3c9b64e746f0f820422bba51bbfa 100644 (file)
@@ -31,11 +31,13 @@ struct mnist_model {
     int nbatch_logical;
     int nbatch_physical;
 
-    struct ggml_tensor  * images = nullptr;
-    struct ggml_tensor  * labels = nullptr;
-    struct ggml_tensor  * logits = nullptr;
-    struct ggml_tensor  * probs  = nullptr;
-    struct ggml_tensor  * loss   = nullptr;
+    struct ggml_tensor  * images    = nullptr;
+    struct ggml_tensor  * labels    = nullptr;
+    struct ggml_tensor  * logits    = nullptr;
+    struct ggml_tensor  * probs     = nullptr;
+    struct ggml_tensor  * loss      = nullptr;
+    struct ggml_tensor  * pred      = nullptr;
+    struct ggml_tensor  * acc_count = nullptr;
 
     struct ggml_tensor * fc1_weight = nullptr;
     struct ggml_tensor * fc1_bias   = nullptr;
@@ -108,6 +110,8 @@ struct mnist_eval_result {
 
     std::vector<float>   loss;
     std::vector<int32_t> pred;
+    int64_t              ncorrect = 0;
+    int64_t              ntotal   = 0;
 };
 
 bool mnist_image_load(const std::string & fname, float * buf, const int nex);
@@ -124,4 +128,4 @@ void              mnist_model_train(mnist_model & model, const float * images, c
 void              mnist_model_save(mnist_model & model, const std::string & fname);
 
 std::pair<double, double> mnist_loss(const mnist_eval_result & result);
-std::pair<double, double> mnist_accuracy(const mnist_eval_result & result, const float * labels);
+std::pair<double, double> mnist_accuracy(const mnist_eval_result & result);
index ae2cadbc5f38670c21655fe5aa29cee0ecf5e020..a5f9676faf085185637de3710488c41bc3249144 100644 (file)
@@ -55,7 +55,7 @@ int main(int argc, char ** argv) {
             std::pair<double, double> result_loss = mnist_loss(result_eval);
             fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second);
 
-            std::pair<double, double> result_acc = mnist_accuracy(result_eval, labels.data());
+            std::pair<double, double> result_acc = mnist_accuracy(result_eval);
             fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second);
 
             return 0;
@@ -79,7 +79,7 @@ int main(int argc, char ** argv) {
     std::pair<double, double> result_loss = mnist_loss(result_eval);
     fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second);
 
-    std::pair<double, double> result_acc = mnist_accuracy(result_eval, labels.data());
+    std::pair<double, double> result_acc = mnist_accuracy(result_eval);
     fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second);
 
     return 0;
index ce3d92cb2e0f060dc9dfc39ea3e16b0943a7fb67..128c63e68b3a7c10aa7e654438bc9616ea2e69de 100644 (file)
@@ -466,6 +466,7 @@ extern "C" {
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_ARGMAX,
+        GGML_OP_COUNT_EQUAL,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
@@ -1004,6 +1005,12 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // count number of equal elements in a and b
+    GGML_API struct ggml_tensor * ggml_count_equal(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // if a is the same shape as b, and a is not parameter, return a
     // otherwise, return a new tensor: repeat(a) to fit in b
     GGML_API struct ggml_tensor * ggml_repeat(
index 6efdab14c36193c2d0b9735975e2cec912d77a9a..77444dc41d7dac8a4b685f6d1eec21b5073898f1 100644 (file)
@@ -5,12 +5,14 @@
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
 #include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
@@ -2178,6 +2180,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     }
 
     switch (dst->op) {
+        case GGML_OP_ARGMAX:
+            ggml_cuda_argmax(ctx, dst);
+            break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_cuda_count_equal(ctx, dst);
+            break;
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
@@ -2929,6 +2937,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 return false;
             } break;
         case GGML_OP_DUP:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
+            {
+                return true;
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_type src0_type = op->src[0]->type;
diff --git a/src/ggml-cuda/argmax.cu b/src/ggml-cuda/argmax.cu
new file mode 100644 (file)
index 0000000..aab04ec
--- /dev/null
@@ -0,0 +1,79 @@
+#include "common.cuh"
+#include "argmax.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+static __global__ void argmax_f32(
+    const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
+
+    int argmax_thread = 0;
+    const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
+
+#pragma unroll
+    for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
+        const int64_t row = row0 + row1;
+
+        if (row >= nrows) {
+            break;
+        }
+
+        float maxval = -FLT_MAX;
+        int   argmax = -1;
+
+        for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
+            const float val        = x[row*ncols + col];
+            const int   bigger     = val > maxval;
+            const int   not_bigger = bigger ^ 0x00000001;
+
+            maxval = maxval*not_bigger + val*bigger;
+            argmax = argmax*not_bigger + col*bigger;
+        }
+
+#pragma unroll
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            const float val        = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
+            const int   col        = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
+            const int   bigger     = val > maxval;
+            const int   not_bigger = bigger ^ 0x00000001;
+
+            maxval = maxval*not_bigger + val*bigger;
+            argmax = argmax*not_bigger + col*bigger;
+        }
+
+        const int store = row1 == threadIdx.x;
+        argmax_thread += store*argmax;
+    }
+
+    const int row = row0 + threadIdx.x;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    dst[row] = argmax_thread;
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    int32_t     * dst_d  = (int32_t     *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(num_blocks, 1, 1);
+
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
+}
diff --git a/src/ggml-cuda/argmax.cuh b/src/ggml-cuda/argmax.cuh
new file mode 100644 (file)
index 0000000..5b7223a
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 6a4bcdba095736ae98396566936f7a024b0a02c7..dd203fcded3aafbd761c54125e38922edcc2b218 100644 (file)
@@ -175,6 +175,18 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+    return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
diff --git a/src/ggml-cuda/count-equal.cu b/src/ggml-cuda/count-equal.cu
new file mode 100644 (file)
index 0000000..ffb053b
--- /dev/null
@@ -0,0 +1,64 @@
+#include "common.cuh"
+#include "count-equal.cuh"
+
+#include <cstdint>
+
+template <typename T>
+static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
+    const int64_t i0 = (int64_t) blockIdx.x*dk;
+    const int64_t i1 = min(i0 + dk, k);
+
+    int nequal = 0;
+
+    for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
+        const T xi = x[i];
+        const T yi = y[i];
+        nequal += xi == yi;
+    }
+
+    nequal = warp_reduce_sum(nequal);
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    atomicAdd((int *) dst, nequal);
+}
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    int64_t * dst_d  = (int64_t *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
+    const int64_t dne = GGML_PAD(ne / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
+
+    CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_I32: {
+            const int * src0_d = (const int *) src0->data;
+            const int * src1_d = (const int *) src1->data;
+            count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
+        } break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
diff --git a/src/ggml-cuda/count-equal.cuh b/src/ggml-cuda/count-equal.cuh
new file mode 100644 (file)
index 0000000..8467da7
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 342f2eb665312d7dd93003a60bcd195ac032ac2e..5af02c7ecbed75ab6f271d04eca360785dd4c584 100644 (file)
@@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16(
         }
 
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
-        kqsum_j = warp_reduce_sum(kqsum_j);
+        kqsum_j = warp_reduce_sum((float)kqsum_j);
 
 #pragma unroll
         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
index 448a9a9054cca112cdfd0b0c4de1e1c6152a38af..2ed6509acb82d2b33a0f3dede3532e7b25a81a2a 100644 (file)
@@ -196,7 +196,7 @@ static __global__ void flash_attn_vec_ext_f16(
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
-                sum = warp_reduce_sum(sum);
+                sum = warp_reduce_sum((float)sum);
 
                 if (use_logit_softcap) {
                     sum = logit_softcap*tanhf(sum);
@@ -265,7 +265,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        kqsum[j] = warp_reduce_sum((float)kqsum[j]);
         if (threadIdx.x == 0) {
             kqsum_shared[j][threadIdx.y] = kqsum[j];
         }
@@ -280,7 +280,7 @@ static __global__ void flash_attn_vec_ext_f16(
         }
 
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+        kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
         if (parallel_blocks == 1) {
index fd411a64b1868b412b61f98783b415fa80884477..5038b689b1e0b278255280aac802bf52318f7303 100644 (file)
@@ -2957,6 +2957,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "ARGMAX",
+    "COUNT_EQUAL",
     "REPEAT",
     "REPEAT_BACK",
     "CONCAT",
@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3051,6 +3052,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "argmax(x)",
+    "count_equal(x)",
     "repeat(x)",
     "repeat_back(x)",
     "concat(x, y)",
@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5185,6 +5187,23 @@ struct ggml_tensor * ggml_argmax(
     return result;
 }
 
+// ggml_count_equal
+
+struct ggml_tensor * ggml_count_equal(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
+
+    result->op     = GGML_OP_COUNT_EQUAL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_repeat
 
 struct ggml_tensor * ggml_repeat(
@@ -10772,6 +10791,86 @@ static void ggml_compute_forward_argmax(
     }
 }
 
+// ggml_compute_forward_count_equal
+
+static void ggml_compute_forward_count_equal_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_I64);
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t * sums = (int64_t *) params->wdata;
+    int64_t sum_thread = 0;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 =  ir                        / (ne02*ne01);
+        const int64_t i02 = (ir - i03*ne03)            /       ne01;
+        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
+
+        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
+        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
+
+        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
+            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
+
+            sum_thread += val0 == val1;
+        }
+    }
+    if (ith != 0) {
+        sums[ith] = sum_thread;
+    }
+    ggml_barrier(params->threadpool);
+
+    if (ith != 0) {
+        return;
+    }
+
+    for (int ith_other = 1; ith_other < nth; ++ith_other) {
+        sum_thread += sums[ith_other];
+    }
+    *((int64_t *) dst->data) = sum_thread;
+}
+
+static void ggml_compute_forward_count_equal(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_count_equal_i32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_repeat
 
 static void ggml_compute_forward_repeat_f32(
@@ -17146,6 +17245,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argmax(params, tensor);
             } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                ggml_compute_forward_count_equal(params, tensor);
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_compute_forward_repeat(params, tensor);
@@ -17896,6 +17999,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
             {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
@@ -18669,6 +18773,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     for (int i = 0; i < gf->n_nodes; ++i) {
         struct ggml_tensor * node = gf->nodes[i];
 
+        if (node->type == GGML_TYPE_I32) {
+            continue;
+        }
+
         bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
@@ -19072,6 +19180,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
         case GGML_OP_LEAKY_RELU:
@@ -19570,6 +19685,10 @@ struct ggml_cplan ggml_graph_plan(
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
                 } break;
+            case GGML_OP_COUNT_EQUAL:
+                {
+                    cur = ggml_type_size(node->type)*n_tasks;
+                } break;
             case GGML_OP_MUL_MAT:
                 {
                     const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
index 95d983aa083c3811772949805d0bf7d5a08c36ce..120eba10749581eb7acaab9e768c62955a7ba8ed 100644 (file)
@@ -116,6 +116,11 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
         ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
+    } else if (tensor->type == GGML_TYPE_I64) {
+        // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful.
+        const size_t nbytes_half = ggml_nbytes(tensor)/2;
+        ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half);
+        ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -145,6 +150,8 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
                         tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
                     } else if (t->type == GGML_TYPE_F32) {
                         tv.push_back(*(float *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I64) {
+                        tv.push_back((float)*(int64_t *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I32) {
                         tv.push_back((float)*(int32_t *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I16) {
@@ -1119,6 +1126,71 @@ struct test_get_rows : public test_case {
     }
 };
 
+// GGML_OP_ARGMAX
+struct test_argmax : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_argmax(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 100, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_argmax(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
+// GGML_OP_COUNT_EQUAL
+struct test_count_equal : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_count_equal(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {4, 500, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * a_argmax = ggml_argmax(ctx, a);
+        ggml_set_name(a_argmax, "a_argmax");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_set_name(b_argmax, "b_argmax");
+
+        ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
 // GGML_OP_REPEAT
 struct test_repeat : public test_case {
     const ggml_type type;
@@ -3263,6 +3335,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
+    test_cases.emplace_back(new test_argmax());
+    test_cases.emplace_back(new test_count_equal());
+
     for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
@@ -3281,8 +3356,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3}));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous
-    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {1, 2, 0, 3}));
 
     for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
         test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));