From: Johannes Gäßler Date: Thu, 3 Oct 2024 15:29:59 +0000 (+0200) Subject: ggml/ex: calculate accuracy in graph, adapt MNIST (#980) X-Git-Tag: upstream/0.0.1642~318 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e5c233e5edbfcfa1d808b9293de9065035c40751;p=pkg%2Fggml%2Fsources%2Fggml ggml/ex: calculate accuracy in graph, adapt MNIST (#980) --- diff --git a/examples/mnist/mnist-common.cpp b/examples/mnist/mnist-common.cpp index 4b7374e0..c70dfd7f 100644 --- a/examples/mnist/mnist-common.cpp +++ b/examples/mnist/mnist-common.cpp @@ -480,24 +480,47 @@ void mnist_model_build(mnist_model & model, const int nbatch_logical, const int GGML_ASSERT(model.loss->ne[1] == 1); GGML_ASSERT(model.loss->ne[2] == 1); GGML_ASSERT(model.loss->ne[3] == 1); + + model.pred = ggml_argmax(model.ctx_compute, model.logits); + ggml_set_name(model.pred, "predictions"); + ggml_set_output(model.pred); + GGML_ASSERT(model.pred->type == GGML_TYPE_I32); + GGML_ASSERT(model.pred->ne[0] == model.nbatch_physical); + GGML_ASSERT(model.pred->ne[1] == 1); + GGML_ASSERT(model.pred->ne[2] == 1); + GGML_ASSERT(model.pred->ne[3] == 1); + + model.acc_count = ggml_count_equal(model.ctx_compute, model.pred, ggml_argmax(model.ctx_compute, model.labels)); + ggml_set_name(model.acc_count, "accuracy_count"); + ggml_set_output(model.acc_count); + GGML_ASSERT(model.acc_count->type == GGML_TYPE_I64); + GGML_ASSERT(model.acc_count->ne[0] == 1); + GGML_ASSERT(model.acc_count->ne[1] == 1); + GGML_ASSERT(model.acc_count->ne[2] == 1); + GGML_ASSERT(model.acc_count->ne[3] == 1); } mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, const float * labels, const int nex) { mnist_eval_result result; struct ggml_cgraph * gf = ggml_new_graph(model.ctx_compute); + // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed. 
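// Editor's note (illustrative, not part of the patch): ggml_build_forward_expand() only appends
// nodes that are not already present in the cgraph, so expanding loss, pred and acc_count into
// the same graph reuses the shared logits subgraph; only the diverging tails (argmax/count_equal
// vs. the cross-entropy loss) are added on top.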
ggml_build_forward_expand(gf, model.loss); + ggml_build_forward_expand(gf, model.pred); + ggml_build_forward_expand(gf, model.acc_count); model.buf_compute = ggml_backend_alloc_ctx_tensors(model.ctx_compute, model.backend); { const int64_t t_start_us = ggml_time_us(); - float loss; - std::vector logits(model.nbatch_physical*MNIST_NCLASSES); + float tmp_loss; + std::vector tmp_pred(model.nbatch_physical); + int64_t tmp_acc_count; - GGML_ASSERT(sizeof(loss) == ggml_nbytes(model.loss)); - GGML_ASSERT(logits.size() == ggml_nelements(model.logits)); + GGML_ASSERT(sizeof(tmp_loss) == ggml_nbytes(model.loss)); + GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred)); + GGML_ASSERT(sizeof(tmp_acc_count) == ggml_nbytes(model.acc_count)); GGML_ASSERT(nex % model.nbatch_physical == 0); for (int iex0 = 0; iex0 < nex; iex0 += model.nbatch_physical) { @@ -506,15 +529,14 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co ggml_backend_graph_compute(model.backend, gf); - ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss)); - ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits)); - - result.loss.push_back(loss); + ggml_backend_tensor_get(model.loss, &tmp_loss, 0, ggml_nbytes(model.loss)); + ggml_backend_tensor_get(model.pred, tmp_pred.data(), 0, ggml_nbytes(model.pred)); + ggml_backend_tensor_get(model.acc_count, &tmp_acc_count, 0, ggml_nbytes(model.acc_count)); - for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) { - const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES; - result.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb); - } + result.loss.push_back(tmp_loss); + result.pred.insert(result.pred.end(), tmp_pred.begin(), tmp_pred.end()); + result.ncorrect += tmp_acc_count; + result.ntotal += model.nbatch_physical; } const int64_t t_total_us = ggml_time_us() - t_start_us; @@ -530,13 +552,18 @@ mnist_eval_result mnist_model_eval(mnist_model & model, const float * images, co void mnist_model_train(mnist_model & model, const float * images, const float * labels, const int nex, const int nepoch, const float val_split) { const int64_t t_start_us = ggml_time_us(); + const bool accumulate = model.nbatch_physical != model.nbatch_logical; + // gf == graph forward, forward pass only. struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + // The outputs are diverging branches of the graphs, therefore multiple calls to ggml_build_forward_expand are needed. ggml_build_forward_expand(gf, model.loss); + ggml_build_forward_expand(gf, model.pred); + ggml_build_forward_expand(gf, model.acc_count); // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf); - ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true); + ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, accumulate); // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. 
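// Editor's note (illustrative, not part of the patch): the three graphs relate as follows.
//   gf      - forward pass only (loss, pred, acc_count); also reused for the validation split below.
//   gb_grad - gf plus the backward pass; with accumulate == true (nbatch_physical != nbatch_logical)
//             the gradients are summed across physical batches instead of being overwritten.
//   gb_opt  - gb_grad plus the optimizer step; it is run once per logical batch, after which
//             ggml_graph_reset(gb_grad) zeroes the accumulated gradients (optimizer state is kept).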
struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad); @@ -551,10 +578,16 @@ void mnist_model_train(mnist_model & model, const float * images, const float * fprintf(stderr, "%s: epoch %02d start...", __func__, epoch); const int64_t t_start_us = ggml_time_us(); - float loss; - std::vector logits(model.nbatch_physical*MNIST_NCLASSES); int iex0 = 0; + float tmp_loss; + std::vector tmp_pred(model.nbatch_physical); + int64_t tmp_acc_count; + + GGML_ASSERT(sizeof(tmp_loss) == ggml_nbytes(model.loss)); + GGML_ASSERT(sizeof(tmp_pred[0])*tmp_pred.size() == ggml_nbytes(model.pred)); + GGML_ASSERT(sizeof(tmp_acc_count) == ggml_nbytes(model.acc_count)); + mnist_eval_result result_train; for (; iex0 < iex_split; iex0 += model.nbatch_physical) { ggml_backend_tensor_set(model.images, images + iex0*MNIST_NINPUT, 0, ggml_nbytes(model.images)); @@ -570,15 +603,14 @@ void mnist_model_train(mnist_model & model, const float * images, const float * ggml_graph_reset(gb_grad); // Set gradients to zero, do not reset optimizer. } - ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss)); - ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits)); - - result_train.loss.push_back(loss); + ggml_backend_tensor_get(model.loss, &tmp_loss, 0, ggml_nbytes(model.loss)); + ggml_backend_tensor_get(model.pred, tmp_pred.data(), 0, ggml_nbytes(model.pred)); + ggml_backend_tensor_get(model.acc_count, &tmp_acc_count, 0, ggml_nbytes(model.acc_count)); - for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) { - const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES; - result_train.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb); - } + result_train.loss.push_back(tmp_loss); + result_train.pred.insert(result_train.pred.end(), tmp_pred.begin(), tmp_pred.end()); + result_train.ncorrect += tmp_acc_count; + result_train.ntotal += model.nbatch_physical; } mnist_eval_result result_val; @@ -588,20 +620,19 @@ void mnist_model_train(mnist_model & model, const float * images, const float * ggml_backend_graph_compute(model.backend, gf); // For the validation set, only the forward pass is needed. 
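// Editor's aside (illustrative only, not part of the patch): a minimal, self-contained CPU sketch
// of the new accuracy path added by this commit, i.e. ggml_argmax() feeding ggml_count_equal().
// Sizes, values and the constant fill are made up; error handling is omitted. Note also that the
// host-side staging buffers in the example must match the tensor types exactly: float for
// model.loss (F32), int32_t per element for model.pred (I32), int64_t for model.acc_count (I64).

#include "ggml.h"
#include <cstdint>
#include <cstdio>

int main() {
    struct ggml_init_params ip = { /*mem_size =*/ 16*1024*1024, /*mem_buffer =*/ nullptr, /*no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    const int ncls = 10, nbatch = 8;
    struct ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ncls, nbatch);
    struct ggml_tensor * labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ncls, nbatch); // one-hot in the real example
    ggml_set_f32(logits, 1.0f); // constant fill just so the sketch runs deterministically
    ggml_set_f32(labels, 1.0f);

    // per-example class indices (I32), reduced to a single I64 scalar with the number of matches:
    struct ggml_tensor * acc_count = ggml_count_equal(ctx,
            ggml_argmax(ctx, logits), ggml_argmax(ctx, labels));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, acc_count);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    printf("ncorrect = %lld / %d\n", (long long) *(const int64_t *) acc_count->data, nbatch);
    ggml_free(ctx);
}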
- ggml_backend_tensor_get(model.loss, &loss, 0, ggml_nbytes(model.loss)); - ggml_backend_tensor_get(model.logits, logits.data(), 0, ggml_nbytes(model.logits)); - - result_val.loss.push_back(loss); + ggml_backend_tensor_get(model.loss, &tmp_loss, 0, ggml_nbytes(model.loss)); + ggml_backend_tensor_get(model.pred, tmp_pred.data(), 0, ggml_nbytes(model.pred)); + ggml_backend_tensor_get(model.acc_count, &tmp_acc_count, 0, ggml_nbytes(model.acc_count)); - for (int iexb = 0; iexb < model.nbatch_physical; ++iexb) { - const float * logits_iexb = logits.data() + iexb*MNIST_NCLASSES; - result_val.pred.push_back(std::max_element(logits_iexb, logits_iexb + MNIST_NCLASSES) - logits_iexb); - } + result_val.loss.push_back(tmp_loss); + result_val.pred.insert(result_val.pred.end(), tmp_pred.begin(), tmp_pred.end()); + result_val.ncorrect += tmp_acc_count; + result_val.ntotal += model.nbatch_physical; } { const double loss_mean = mnist_loss(result_train).first; - const double percent_correct = 100.0 * mnist_accuracy(result_train, labels + 0*MNIST_NCLASSES).first; + const double percent_correct = 100.0 * mnist_accuracy(result_train).first; const int64_t t_epoch_us = ggml_time_us() - t_start_us; const double t_epoch_s = 1e-6*t_epoch_us; @@ -610,7 +641,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float * if (iex_split < nex) { const std::pair loss = mnist_loss(result_val); - const std::pair acc = mnist_accuracy(result_val, labels + iex_split*MNIST_NCLASSES); + const std::pair acc = mnist_accuracy(result_val); fprintf(stderr, ", val_loss=%.6lf+-%.6lf, val_acc=%.2f+-%.2f%%", loss.first, loss.second, 100.0*acc.first, 100.0*acc.second); } @@ -668,7 +699,7 @@ void mnist_model_save(mnist_model & model, const std::string & fname) { std::pair mnist_loss(const mnist_eval_result & result) { const size_t nbatches = result.loss.size(); - GGML_ASSERT(nbatches >= 1); + GGML_ASSERT(nbatches >= 2); double sum = 0.0; double sum_squared = 0.0; @@ -684,20 +715,12 @@ std::pair mnist_loss(const mnist_eval_result & result) { return std::make_pair(mean, uncertainty); } -std::pair mnist_accuracy(const mnist_eval_result & result, const float * labels) { - const size_t nex = result.pred.size(); - GGML_ASSERT(nex >= 1); - - size_t ncorrect = 0; - for (size_t iex = 0; iex < nex; ++iex) { - const float * labels_iex = labels + iex*MNIST_NCLASSES; - const int32_t label = std::max_element(labels_iex, labels_iex + MNIST_NCLASSES) - labels_iex; - - ncorrect += result.pred[iex] == label; - } +std::pair mnist_accuracy(const mnist_eval_result & result) { + GGML_ASSERT(result.ntotal >= result.ncorrect); + GGML_ASSERT(result.ntotal >= 2); - const double fraction_correct = ((double) ncorrect) / ((double) nex); - const double uncertainty = sqrt(fraction_correct * (1.0 - fraction_correct) / (nex - 1)); + const double fraction_correct = ((double) result.ncorrect) / ((double) result.ntotal); + const double uncertainty = sqrt(fraction_correct * (1.0 - fraction_correct) / (result.ncorrect - 1)); return std::make_pair(fraction_correct, uncertainty); } diff --git a/examples/mnist/mnist-common.h b/examples/mnist/mnist-common.h index 72638a40..a6239a42 100644 --- a/examples/mnist/mnist-common.h +++ b/examples/mnist/mnist-common.h @@ -31,11 +31,13 @@ struct mnist_model { int nbatch_logical; int nbatch_physical; - struct ggml_tensor * images = nullptr; - struct ggml_tensor * labels = nullptr; - struct ggml_tensor * logits = nullptr; - struct ggml_tensor * probs = nullptr; - struct ggml_tensor * loss = nullptr; + struct ggml_tensor 
* images = nullptr; + struct ggml_tensor * labels = nullptr; + struct ggml_tensor * logits = nullptr; + struct ggml_tensor * probs = nullptr; + struct ggml_tensor * loss = nullptr; + struct ggml_tensor * pred = nullptr; + struct ggml_tensor * acc_count = nullptr; struct ggml_tensor * fc1_weight = nullptr; struct ggml_tensor * fc1_bias = nullptr; @@ -108,6 +110,8 @@ struct mnist_eval_result { std::vector loss; std::vector pred; + int64_t ncorrect = 0; + int64_t ntotal = 0; }; bool mnist_image_load(const std::string & fname, float * buf, const int nex); @@ -124,4 +128,4 @@ void mnist_model_train(mnist_model & model, const float * images, c void mnist_model_save(mnist_model & model, const std::string & fname); std::pair mnist_loss(const mnist_eval_result & result); -std::pair mnist_accuracy(const mnist_eval_result & result, const float * labels); +std::pair mnist_accuracy(const mnist_eval_result & result); diff --git a/examples/mnist/mnist-eval.cpp b/examples/mnist/mnist-eval.cpp index ae2cadbc..a5f9676f 100644 --- a/examples/mnist/mnist-eval.cpp +++ b/examples/mnist/mnist-eval.cpp @@ -55,7 +55,7 @@ int main(int argc, char ** argv) { std::pair result_loss = mnist_loss(result_eval); fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second); - std::pair result_acc = mnist_accuracy(result_eval, labels.data()); + std::pair result_acc = mnist_accuracy(result_eval); fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second); return 0; @@ -79,7 +79,7 @@ int main(int argc, char ** argv) { std::pair result_loss = mnist_loss(result_eval); fprintf(stdout, "%s: test_loss=%.6lf+-%.6lf\n", __func__, result_loss.first, result_loss.second); - std::pair result_acc = mnist_accuracy(result_eval, labels.data()); + std::pair result_acc = mnist_accuracy(result_eval); fprintf(stdout, "%s: test_acc=%.2lf+-%.2lf%%\n", __func__, 100.0*result_acc.first, 100.0*result_acc.second); return 0; diff --git a/include/ggml.h b/include/ggml.h index ce3d92cb..128c63e6 100644 --- a/include/ggml.h +++ b/include/ggml.h @@ -466,6 +466,7 @@ extern "C" { GGML_OP_SUM_ROWS, GGML_OP_MEAN, GGML_OP_ARGMAX, + GGML_OP_COUNT_EQUAL, GGML_OP_REPEAT, GGML_OP_REPEAT_BACK, GGML_OP_CONCAT, @@ -1004,6 +1005,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // count number of equal elements in a and b + GGML_API struct ggml_tensor * ggml_count_equal( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // if a is the same shape as b, and a is not parameter, return a // otherwise, return a new tensor: repeat(a) to fit in b GGML_API struct ggml_tensor * ggml_repeat( diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 6efdab14..77444dc4 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -5,12 +5,14 @@ #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" #include "ggml-cuda/arange.cuh" +#include "ggml-cuda/argmax.cuh" #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/convert.cuh" +#include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" @@ -2178,6 +2180,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg } switch (dst->op) { + case GGML_OP_ARGMAX: + ggml_cuda_argmax(ctx, dst); + break; + case GGML_OP_COUNT_EQUAL: + 
ggml_cuda_count_equal(ctx, dst); + break; case GGML_OP_REPEAT: ggml_cuda_op_repeat(ctx, dst); break; @@ -2929,6 +2937,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } break; case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + { + return true; + } break; case GGML_OP_REPEAT: { ggml_type src0_type = op->src[0]->type; diff --git a/src/ggml-cuda/argmax.cu b/src/ggml-cuda/argmax.cu new file mode 100644 index 00000000..aab04eca --- /dev/null +++ b/src/ggml-cuda/argmax.cu @@ -0,0 +1,79 @@ +#include "common.cuh" +#include "argmax.cuh" +#include "sum.cuh" + +#include + +static __global__ void argmax_f32( + const float * x, int32_t * dst, const int64_t ncols, const int64_t nrows) { + + int argmax_thread = 0; + const int64_t row0 = (int64_t)blockIdx.x*WARP_SIZE; + +#pragma unroll + for (int64_t row1 = 0; row1 < WARP_SIZE; ++row1) { + const int64_t row = row0 + row1; + + if (row >= nrows) { + break; + } + + float maxval = -FLT_MAX; + int argmax = -1; + + for (int32_t col = threadIdx.x; col < ncols; col += WARP_SIZE) { + const float val = x[row*ncols + col]; + const int bigger = val > maxval; + const int not_bigger = bigger ^ 0x00000001; + + maxval = maxval*not_bigger + val*bigger; + argmax = argmax*not_bigger + col*bigger; + } + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, mask, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, mask, WARP_SIZE); + const int bigger = val > maxval; + const int not_bigger = bigger ^ 0x00000001; + + maxval = maxval*not_bigger + val*bigger; + argmax = argmax*not_bigger + col*bigger; + } + + const int store = row1 == threadIdx.x; + argmax_thread += store*argmax; + } + + const int row = row0 + threadIdx.x; + + if (row >= nrows) { + return; + } + + dst[row] = argmax_thread; +} + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + int32_t * dst_d = (int32_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t num_blocks = (nrows + WARP_SIZE - 1) / WARP_SIZE; + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(num_blocks, 1, 1); + + argmax_f32<<>>(src0_d, dst_d, ne00, nrows); +} diff --git a/src/ggml-cuda/argmax.cuh b/src/ggml-cuda/argmax.cuh new file mode 100644 index 00000000..5b7223ad --- /dev/null +++ b/src/ggml-cuda/argmax.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh index 6a4bcdba..dd203fcd 100644 --- a/src/ggml-cuda/common.cuh +++ b/src/ggml-cuda/common.cuh @@ -175,6 +175,18 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +static __device__ __forceinline__ int warp_reduce_sum(int x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE + return __reduce_add_sync(0xffffffff, x); +#else +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x 
+= __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE +} + static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { diff --git a/src/ggml-cuda/count-equal.cu b/src/ggml-cuda/count-equal.cu new file mode 100644 index 00000000..ffb053b1 --- /dev/null +++ b/src/ggml-cuda/count-equal.cu @@ -0,0 +1,64 @@ +#include "common.cuh" +#include "count-equal.cuh" + +#include + +template +static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) { + const int64_t i0 = (int64_t) blockIdx.x*dk; + const int64_t i1 = min(i0 + dk, k); + + int nequal = 0; + + for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) { + const T xi = x[i]; + const T yi = y[i]; + nequal += xi == yi; + } + + nequal = warp_reduce_sum(nequal); + + if (threadIdx.x != 0) { + return; + } + + atomicAdd((int *) dst, nequal); +} + +void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT( dst->type == GGML_TYPE_I64); + + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + int64_t * dst_d = (int64_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int"); + const int64_t dne = GGML_PAD(ne / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE); + + CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream)); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1); + + switch (src0->type) { + case GGML_TYPE_I32: { + const int * src0_d = (const int *) src0->data; + const int * src1_d = (const int *) src1->data; + count_equal<<>>(src0_d, src1_d, dst_d, dne, ne); + } break; + default: + GGML_ASSERT(false); + break; + } +} diff --git a/src/ggml-cuda/count-equal.cuh b/src/ggml-cuda/count-equal.cuh new file mode 100644 index 00000000..8467da79 --- /dev/null +++ b/src/ggml-cuda/count-equal.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128 + +void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/src/ggml-cuda/fattn-tile-f16.cu b/src/ggml-cuda/fattn-tile-f16.cu index 342f2eb6..5af02c7e 100644 --- a/src/ggml-cuda/fattn-tile-f16.cu +++ b/src/ggml-cuda/fattn-tile-f16.cu @@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16( } half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum(kqsum_j); + kqsum_j = warp_reduce_sum((float)kqsum_j); #pragma unroll for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { diff --git a/src/ggml-cuda/fattn-vec-f16.cuh b/src/ggml-cuda/fattn-vec-f16.cuh index 448a9a90..2ed6509a 100644 --- a/src/ggml-cuda/fattn-vec-f16.cuh +++ b/src/ggml-cuda/fattn-vec-f16.cuh @@ -196,7 +196,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); - 
sum = warp_reduce_sum(sum); + sum = warp_reduce_sum((float)sum); if (use_logit_softcap) { sum = logit_softcap*tanhf(sum); @@ -265,7 +265,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); + kqsum[j] = warp_reduce_sum((float)kqsum[j]); if (threadIdx.x == 0) { kqsum_shared[j][threadIdx.y] = kqsum[j]; } @@ -280,7 +280,7 @@ static __global__ void flash_attn_vec_ext_f16( } kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); if (parallel_blocks == 1) { diff --git a/src/ggml.c b/src/ggml.c index fd411a64..5038b689 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -2957,6 +2957,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SUM_ROWS", "MEAN", "ARGMAX", + "COUNT_EQUAL", "REPEAT", "REPEAT_BACK", "CONCAT", @@ -3030,7 +3031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3051,6 +3052,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "Σx_k", "Σx/n", "argmax(x)", + "count_equal(x)", "repeat(x)", "repeat_back(x)", "concat(x, y)", @@ -3124,7 +3126,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); +static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5185,6 +5187,23 @@ struct ggml_tensor * ggml_argmax( return result; } +// ggml_count_equal + +struct ggml_tensor * ggml_count_equal( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1); + + result->op = GGML_OP_COUNT_EQUAL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_repeat struct ggml_tensor * ggml_repeat( @@ -10772,6 +10791,86 @@ static void ggml_compute_forward_argmax( } } +// ggml_compute_forward_count_equal + +static void ggml_compute_forward_count_equal_i32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_I32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(dst->type == GGML_TYPE_I64); + + const int64_t nr = ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + int64_t * sums = (int64_t *) params->wdata; + int64_t sum_thread = 0; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02*ne01); + const int64_t i02 = (ir - i03*ne03) / ne01; + const int64_t i01 = ir - i03*ne03 - i02*ne02; + + const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; + const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; + + for (int64_t i00 = 0; i00 < ne00; ++i00) { + const int32_t val0 = *((const int32_t 
*) (data0 + i00*nb00)); + const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); + + sum_thread += val0 == val1; + } + } + if (ith != 0) { + sums[ith] = sum_thread; + } + ggml_barrier(params->threadpool); + + if (ith != 0) { + return; + } + + for (int ith_other = 1; ith_other < nth; ++ith_other) { + sum_thread += sums[ith_other]; + } + *((int64_t *) dst->data) = sum_thread; +} + +static void ggml_compute_forward_count_equal( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_I32: + { + ggml_compute_forward_count_equal_i32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_repeat static void ggml_compute_forward_repeat_f32( @@ -17146,6 +17245,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_argmax(params, tensor); } break; + case GGML_OP_COUNT_EQUAL: + { + ggml_compute_forward_count_equal(params, tensor); + } break; case GGML_OP_REPEAT: { ggml_compute_forward_repeat(params, tensor); @@ -17896,6 +17999,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_MEAN: case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: { GGML_ABORT("fatal error"); // TODO: implement } @@ -18669,6 +18773,10 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * for (int i = 0; i < gf->n_nodes; ++i) { struct ggml_tensor * node = gf->nodes[i]; + if (node->type == GGML_TYPE_I32) { + continue; + } + bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; bool ignore_src[GGML_MAX_SRC] = {false}; switch (node->op) { @@ -19072,6 +19180,13 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_ARGMAX: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT_EQUAL: + { + n_tasks = n_threads; + } break; case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_LEAKY_RELU: @@ -19570,6 +19685,10 @@ struct ggml_cplan ggml_graph_plan( cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; } } break; + case GGML_OP_COUNT_EQUAL: + { + cur = ggml_type_size(node->type)*n_tasks; + } break; case GGML_OP_MUL_MAT: { const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 95d983aa..120eba10 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -116,6 +116,11 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) { // This is going to create some weird integers though. ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor)); + } else if (tensor->type == GGML_TYPE_I64) { + // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful. 
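// Editor's note (not part of the patch): for an I64 tensor ggml_nbytes() is twice the size of the
// float staging buffer (8 vs. 4 bytes per element), hence the write below is split into two halves
// at offsets 0 and nbytes_half, reusing the same float bytes for both.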
+ const size_t nbytes_half = ggml_nbytes(tensor)/2; + ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half); + ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half); } else { GGML_ABORT("fatal error"); } @@ -145,6 +150,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i])); } else if (t->type == GGML_TYPE_F32) { tv.push_back(*(float *) &buf[i]); + } else if (t->type == GGML_TYPE_I64) { + tv.push_back((float)*(int64_t *) &buf[i]); } else if (t->type == GGML_TYPE_I32) { tv.push_back((float)*(int32_t *) &buf[i]); } else if (t->type == GGML_TYPE_I16) { @@ -1119,6 +1126,71 @@ struct test_get_rows : public test_case { } }; +// GGML_OP_ARGMAX +struct test_argmax : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_argmax(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 100, 1, 1}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_argmax(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + double max_nmse_err() override { + return 0.0; + } +}; + +// GGML_OP_COUNT_EQUAL +struct test_count_equal : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_count_equal(ggml_type type = GGML_TYPE_F32, + std::array ne = {4, 500, 1, 1}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * a_argmax = ggml_argmax(ctx, a); + ggml_set_name(a_argmax, "a_argmax"); + + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(b, "b"); + + ggml_tensor * b_argmax = ggml_argmax(ctx, a); + ggml_set_name(b_argmax, "b_argmax"); + + ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax); + ggml_set_name(out, "out"); + + return out; + } + + double max_nmse_err() override { + return 0.0; + } +}; + // GGML_OP_REPEAT struct test_repeat : public test_case { const ggml_type type; @@ -3263,6 +3335,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); + test_cases.emplace_back(new test_argmax()); + test_cases.emplace_back(new test_count_equal()); + for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1 test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1})); test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1})); @@ -3281,8 +3356,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3})); test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous - test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3})); - test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); + test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3})); + test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 
1}, {1, 2, 0, 3})); for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));