]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml/ex: calculate accuracy in graph, adapt MNIST (ggml/980)
authorJohannes Gäßler <redacted>
Thu, 3 Oct 2024 15:29:59 +0000 (17:29 +0200)
committerGeorgi Gerganov <redacted>
Sat, 5 Oct 2024 12:23:51 +0000 (15:23 +0300)
ggml/include/ggml.h
ggml/src/ggml-cuda.cu
ggml/src/ggml-cuda/argmax.cu [new file with mode: 0644]
ggml/src/ggml-cuda/argmax.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/count-equal.cu [new file with mode: 0644]
ggml/src/ggml-cuda/count-equal.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/fattn-tile-f16.cu
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml.c

index ce3d92cb2e0f060dc9dfc39ea3e16b0943a7fb67..128c63e68b3a7c10aa7e654438bc9616ea2e69de 100644 (file)
@@ -466,6 +466,7 @@ extern "C" {
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
         GGML_OP_ARGMAX,
+        GGML_OP_COUNT_EQUAL,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
@@ -1004,6 +1005,12 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // count number of equal elements in a and b
+    GGML_API struct ggml_tensor * ggml_count_equal(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // if a is the same shape as b, and a is not parameter, return a
     // otherwise, return a new tensor: repeat(a) to fit in b
     GGML_API struct ggml_tensor * ggml_repeat(
index 6efdab14c36193c2d0b9735975e2cec912d77a9a..77444dc41d7dac8a4b685f6d1eec21b5073898f1 100644 (file)
@@ -5,12 +5,14 @@
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
 #include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
 #include "ggml-cuda/clamp.cuh"
 #include "ggml-cuda/concat.cuh"
 #include "ggml-cuda/conv-transpose-1d.cuh"
 #include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/count-equal.cuh"
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/cross-entropy-loss.cuh"
 #include "ggml-cuda/diagmask.cuh"
@@ -2178,6 +2180,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     }
 
     switch (dst->op) {
+        case GGML_OP_ARGMAX:
+            ggml_cuda_argmax(ctx, dst);
+            break;
+        case GGML_OP_COUNT_EQUAL:
+            ggml_cuda_count_equal(ctx, dst);
+            break;
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
@@ -2929,6 +2937,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 return false;
             } break;
         case GGML_OP_DUP:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+            } break;
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
+            {
+                return true;
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_type src0_type = op->src[0]->type;
diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu
new file mode 100644 (file)
index 0000000..aab04ec
--- /dev/null
@@ -0,0 +1,79 @@
+#include "common.cuh"
+#include "argmax.cuh"
+#include "sum.cuh"
+
+#include <cstdint>
+
+static __global__ void argmax_f32(
+    const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) {
+
+    int argmax_thread = 0;
+    const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE;
+
+#pragma unroll
+    for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) {
+        const int64_t row = row0 + row1;
+
+        if (row >= nrows) {
+            break;
+        }
+
+        float maxval = -FLT_MAX;
+        int   argmax = -1;
+
+        for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) {
+            const float val        = x[row*ncols + col];
+            const int   bigger     = val > maxval;
+            const int   not_bigger = bigger ^ 0x00000001;
+
+            maxval = maxval*not_bigger + val*bigger;
+            argmax = argmax*not_bigger + col*bigger;
+        }
+
+#pragma unroll
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            const float val        = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE);
+            const int   col        = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE);
+            const int   bigger     = val > maxval;
+            const int   not_bigger = bigger ^ 0x00000001;
+
+            maxval = maxval*not_bigger + val*bigger;
+            argmax = argmax*not_bigger + col*bigger;
+        }
+
+        const int store = row1 == threadIdx.x;
+        argmax_thread += store*argmax;
+    }
+
+    const int row = row0 + threadIdx.x;
+
+    if (row >= nrows) {
+        return;
+    }
+
+    dst[row] = argmax_thread;
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00  = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    int32_t     * dst_d  = (int32_t     *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE;
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(num_blocks, 1, 1);
+
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00, nrows);
+}
diff --git a/ggml/src/ggml-cuda/argmax.cuh b/ggml/src/ggml-cuda/argmax.cuh
new file mode 100644 (file)
index 0000000..5b7223a
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 6a4bcdba095736ae98396566936f7a024b0a02c7..dd203fcded3aafbd761c54125e38922edcc2b218 100644 (file)
@@ -175,6 +175,18 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+    return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
diff --git a/ggml/src/ggml-cuda/count-equal.cu b/ggml/src/ggml-cuda/count-equal.cu
new file mode 100644 (file)
index 0000000..ffb053b
--- /dev/null
@@ -0,0 +1,64 @@
+#include "common.cuh"
+#include "count-equal.cuh"
+
+#include <cstdint>
+
+template <typename T>
+static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
+    const int64_t i0 = (int64_t) blockIdx.x*dk;
+    const int64_t i1 = min(i0 + dk, k);
+
+    int nequal = 0;
+
+    for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
+        const T xi = x[i];
+        const T yi = y[i];
+        nequal += xi == yi;
+    }
+
+    nequal = warp_reduce_sum(nequal);
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    atomicAdd((int *) dst, nequal);
+}
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == src1->type);
+    GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    int64_t * dst_d  = (int64_t *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
+    const int64_t dne = GGML_PAD(ne / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
+
+    CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
+
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
+
+    switch (src0->type) {
+        case GGML_TYPE_I32: {
+            const int * src0_d = (const int *) src0->data;
+            const int * src1_d = (const int *) src1->data;
+            count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
+        } break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+}
diff --git a/ggml/src/ggml-cuda/count-equal.cuh b/ggml/src/ggml-cuda/count-equal.cuh
new file mode 100644 (file)
index 0000000..8467da7
--- /dev/null
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 342f2eb665312d7dd93003a60bcd195ac032ac2e..5af02c7ecbed75ab6f271d04eca360785dd4c584 100644 (file)
@@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16(
         }
 
         half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
-        kqsum_j = warp_reduce_sum(kqsum_j);
+        kqsum_j = warp_reduce_sum((float)kqsum_j);
 
 #pragma unroll
         for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
index 448a9a9054cca112cdfd0b0c4de1e1c6152a38af..2ed6509acb82d2b33a0f3dede3532e7b25a81a2a 100644 (file)
@@ -196,7 +196,7 @@ static __global__ void flash_attn_vec_ext_f16(
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
-                sum = warp_reduce_sum(sum);
+                sum = warp_reduce_sum((float)sum);
 
                 if (use_logit_softcap) {
                     sum = logit_softcap*tanhf(sum);
@@ -265,7 +265,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        kqsum[j] = warp_reduce_sum((float)kqsum[j]);
         if (threadIdx.x == 0) {
             kqsum_shared[j][threadIdx.y] = kqsum[j];
         }
@@ -280,7 +280,7 @@ static __global__ void flash_attn_vec_ext_f16(
         }
 
         kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+        kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
         if (parallel_blocks == 1) {
index fd411a64b1868b412b61f98783b415fa80884477..5038b689b1e0b278255280aac802bf52318f7303 100644 (file)
@@ -2957,6 +2957,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SUM_ROWS",
     "MEAN",
     "ARGMAX",
+    "COUNT_EQUAL",
     "REPEAT",
     "REPEAT_BACK",
     "CONCAT",
@@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3051,6 +3052,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "Σx_k",
     "Σx/n",
     "argmax(x)",
+    "count_equal(x)",
     "repeat(x)",
     "repeat_back(x)",
     "concat(x, y)",
@@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
+static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5185,6 +5187,23 @@ struct ggml_tensor * ggml_argmax(
     return result;
 }
 
+// ggml_count_equal
+
+struct ggml_tensor * ggml_count_equal(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1);
+
+    result->op     = GGML_OP_COUNT_EQUAL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
 // ggml_repeat
 
 struct ggml_tensor * ggml_repeat(
@@ -10772,6 +10791,86 @@ static void ggml_compute_forward_argmax(
     }
 }
 
+// ggml_compute_forward_count_equal
+
+static void ggml_compute_forward_count_equal_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    GGML_ASSERT(src0->type == GGML_TYPE_I32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_I64);
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    int64_t * sums = (int64_t *) params->wdata;
+    int64_t sum_thread = 0;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 =  ir                        / (ne02*ne01);
+        const int64_t i02 = (ir - i03*ne03)            /       ne01;
+        const int64_t i01 =  ir - i03*ne03 - i02*ne02;
+
+        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
+        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
+
+        for (int64_t i00 = 0; i00 < ne00; ++i00) {
+            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
+            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
+
+            sum_thread += val0 == val1;
+        }
+    }
+    if (ith != 0) {
+        sums[ith] = sum_thread;
+    }
+    ggml_barrier(params->threadpool);
+
+    if (ith != 0) {
+        return;
+    }
+
+    for (int ith_other = 1; ith_other < nth; ++ith_other) {
+        sum_thread += sums[ith_other];
+    }
+    *((int64_t *) dst->data) = sum_thread;
+}
+
+static void ggml_compute_forward_count_equal(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_count_equal_i32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_repeat
 
 static void ggml_compute_forward_repeat_f32(
@@ -17146,6 +17245,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_argmax(params, tensor);
             } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                ggml_compute_forward_count_equal(params, tensor);
+            } break;
         case GGML_OP_REPEAT:
             {
                 ggml_compute_forward_repeat(params, tensor);
@@ -17896,6 +17999,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+        case GGML_OP_COUNT_EQUAL:
             {
                 GGML_ABORT("fatal error"); // TODO: implement
             }
@@ -18669,6 +18773,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
     for (int i = 0; i < gf->n_nodes; ++i) {
         struct ggml_tensor * node = gf->nodes[i];
 
+        if (node->type == GGML_TYPE_I32) {
+            continue;
+        }
+
         bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
@@ -19072,6 +19180,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_ARGMAX:
+            {
+                n_tasks = 1;
+            } break;
+        case GGML_OP_COUNT_EQUAL:
+            {
+                n_tasks = n_threads;
+            } break;
         case GGML_OP_REPEAT:
         case GGML_OP_REPEAT_BACK:
         case GGML_OP_LEAKY_RELU:
@@ -19570,6 +19685,10 @@ struct ggml_cplan ggml_graph_plan(
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
                     }
                 } break;
+            case GGML_OP_COUNT_EQUAL:
+                {
+                    cur = ggml_type_size(node->type)*n_tasks;
+                } break;
             case GGML_OP_MUL_MAT:
                 {
                     const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;