From: Georgi Gerganov Date: Sun, 23 Jul 2023 14:51:29 +0000 (+0300) Subject: ggml : sync llama.cpp (#409) X-Git-Tag: upstream/0.0.1642~1301 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=27e2c5b63484bde71667d7a48c53a82d0434f620;p=pkg%2Fggml%2Fsources%2Fggml ggml : sync llama.cpp (#409) * ggml : sync llama.cpp ggml-ci * ggml : fix nullptr derefs in backward * ci : add mnist test, import/export graph * add op_params to ggml_graph_export/import ggml-ci * mnist : export/import op_params for testing purposes * mnist : fix f32 model generation test + instructions ggml-ci * ci : install python deps even for low-perf builds ggml-ci --------- Co-authored-by: Diego Devesa --- diff --git a/ci/run.sh b/ci/run.sh index 973a0fe3..abb43e76 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -1,4 +1,15 @@ #/bin/bash +# +# sample usage: +# +# mkdir tmp +# +# # CPU-only build +# bash ./ci/run.sh ./tmp/results ./tmp/mnt +# +# # with CUDA support +# GG_BUILD_CUDA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt +# if [ -z "$2" ]; then echo "usage: $0 " @@ -190,23 +201,57 @@ function gg_sum_mpt { gg_printf '```\n' } +# mnist + +function gg_run_mnist { + cd ${SRC} + + cd build-ci-release + + set -e + + mkdir -p models/mnist + python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict + + model_f32="./models/mnist/ggml-model-f32.bin" + samples="../examples/mnist/models/mnist/t10k-images.idx3-ubyte" + + # first command runs and exports "mnist.ggml", the second command runs the exported model + + (time ./bin/mnist ${model_f32} ${samples} ) 2>&1 | tee -a $OUT/${ci}-mnist.log + (time ./bin/mnist-cpu ./mnist.ggml ${samples} ) 2>&1 | tee -a $OUT/${ci}-mnist.log + + set +e +} + +function gg_sum_mnist { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'MNIST\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '```\n' + gg_printf '%s\n' "$(cat $OUT/${ci}-mnist.log)" + gg_printf '```\n' +} + ## main if [ -z $GG_BUILD_LOW_PERF ]; then rm -rf ${SRC}/models-mnt - mnt_models=$(realpath ${MNT}/models) + mnt_models=${MNT}/models mkdir -p ${mnt_models} ln -sfn ${mnt_models} ${SRC}/models-mnt - - python3 -m pip install -r ${SRC}/requirements.txt fi +python3 -m pip install -r ${SRC}/requirements.txt + ret=0 test $ret -eq 0 && gg_run ctest_debug test $ret -eq 0 && gg_run ctest_release test $ret -eq 0 && gg_run gpt_2 +test $ret -eq 0 && gg_run mnist if [ -z $GG_BUILD_LOW_PERF ]; then test $ret -eq 0 && gg_run mpt diff --git a/examples/mnist/README.md b/examples/mnist/README.md index 0f2ed8c2..3bb436cb 100644 --- a/examples/mnist/README.md +++ b/examples/mnist/README.md @@ -41,8 +41,13 @@ mkdir build && cd build cmake .. make -j4 mnist +# Generate ggml model +mkdir -p models/mnist +python3 ../examples/mnist/convert-h5-to-ggml.py ../examples/mnist/models/mnist/mnist_model.state_dict + # Run the MNIST model -./bin/mnist ../examples/mnist/models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte + +./bin/mnist ./models/mnist/ggml-model-f32.bin ../examples/mnist/models/mnist/t10k-images.idx3-ubyte ``` For more information, checkout the corresponding programs in the [examples](examples) folder. diff --git a/examples/mnist/main-cpu.cpp b/examples/mnist/main-cpu.cpp index 2000c9aa..3e8bfe67 100644 --- a/examples/mnist/main-cpu.cpp +++ b/examples/mnist/main-cpu.cpp @@ -42,6 +42,9 @@ int mnist_eval( struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval); + // param export/import test + GGML_ASSERT(ggml_graph_get_tensor(&gfi, "fc1_bias")->op_params[0] == 0xdeadbeef); + // allocate work context // needed during ggml_graph_compute() to allocate a work tensor static size_t buf_size = 128ull*1024*1024; // TODO diff --git a/examples/mnist/main.cpp b/examples/mnist/main.cpp index 5ff4ac20..33986e4e 100644 --- a/examples/mnist/main.cpp +++ b/examples/mnist/main.cpp @@ -119,6 +119,9 @@ bool mnist_model_load(const std::string & fname, mnist_model & model) { model.fc1_bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_hidden); fin.read(reinterpret_cast(model.fc1_bias->data), ggml_nbytes(model.fc1_bias)); ggml_set_name(model.fc1_bias, "fc1_bias"); + + // just for testing purposes, set some parameters to non-zero + model.fc1_bias->op_params[0] = 0xdeadbeef; } } diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 24856a25..871c85a8 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -199,6 +199,7 @@ #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 6 #define GGML_MAX_NAME 48 +#define GGML_MAX_OP_PARAMS 32 #define GGML_DEFAULT_N_THREADS 4 @@ -418,6 +419,9 @@ extern "C" { // compute data enum ggml_op op; + // op params - allocated as int32_t for alignment + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)]; + bool is_param; struct ggml_tensor * grad; @@ -1128,9 +1132,9 @@ extern "C" { int n_past, int n_dims, int mode, + int n_ctx, float freq_base, - float freq_scale, - int n_ctx); + float freq_scale); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1139,7 +1143,8 @@ extern "C" { struct ggml_tensor * a, int n_past, int n_dims, - int mode); + int mode, + int n_ctx); // alibi position embedding // in-place, returns view(a) diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index d3054a7f..6fb55d83 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -220,7 +220,7 @@ typedef struct { static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); #define WARP_SIZE 32 -#define MATRIX_ROW_PADDING 256 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses +#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #define CUDA_ADD_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256 @@ -935,12 +935,18 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + float tmp = 0; // partial sum for thread in warp for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint8_t * q1 = x[i].qs + q_offset; - const uint8_t * q2 = q1 + 64; const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; @@ -953,14 +959,41 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + float4 s = {0.f, 0.f, 0.f, 0.f}; float smin = 0; - for (int l = 0; l < n; ++l) { - s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4); - s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4); + for (int l = 0; l < 4; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; + s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; } - tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin; + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif } #else @@ -1521,7 +1554,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q4_K * bq4_K = (const block_q4_K *) vbq; - const int bq8_offset = QR4_K * (iqs / QI8_1); + const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6 float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -1531,11 +1564,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]); - for (int i = 0; i < QR4_K; ++i) { - const int isc = bq8_offset + i; + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; - uint8_t sc, m; - get_scale_min_k4(isc, bq4_K->scales, sc, m); + for (int i = 0; i < QR4_K; ++i) { const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]); @@ -1543,8 +1585,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( const int vi = (v >> (4*i)) & 0x0F0F0F0F; - sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product - sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values + sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product + sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values } return d*sumf_d - dmin*sumf_m; @@ -1745,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } } -static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) { +static __global__ void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / (nchannels_y / nchannels_x); const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1765,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const } // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -1793,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x) { + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / channel_x_divisor; const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1815,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous break; } - const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -2324,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } -static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); +static void ggml_mul_mat_p021_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int nchannels_x, const int nchannels_y, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); } static void ggml_mul_mat_vec_nc_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { + const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); } static void ggml_cpy_f32_f32_cuda( @@ -2423,20 +2473,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); int id; CUDA_CHECK(cudaGetDevice(&id)); - +#ifdef DEBUG_CUDA_MALLOC + int nnz = 0; + size_t max_size = 0, tot_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { cuda_buffer& b = g_cuda_buffer_pool[id][i]; - if (b.size >= size && b.ptr != nullptr) { - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; + if (b.ptr != nullptr) { +#ifdef DEBUG_CUDA_MALLOC + ++nnz; + tot_size += b.size; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } } } + if (ibest >= 0) { + cuda_buffer& b = g_cuda_buffer_pool[id][ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } +#ifdef DEBUG_CUDA_MALLOC + fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024)); +#endif void * ptr; - CUDA_CHECK(cudaMalloc((void **) &ptr, size)); - *actual_size = size; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); + CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + *actual_size = look_ahead_size; return ptr; } @@ -2464,7 +2547,9 @@ static size_t g_scratch_offset = 0; static int g_device_count = -1; static int g_main_device = 0; +#ifndef GGML_CUDA_FORCE_DMMV static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; +#endif static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -2487,7 +2572,9 @@ void ggml_init_cublas() { g_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; +#ifndef GGML_CUDA_FORCE_DMMV g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; +#endif } for (int id = 0; id < g_device_count; ++id) { g_tensor_split[id] /= total_vram; @@ -2512,6 +2599,9 @@ void ggml_init_cublas() { } void ggml_cuda_set_tensor_split(const float * tensor_split) { + if (tensor_split == nullptr) { + return; + } bool all_zero = true; for (int i = 0; i < g_device_count; ++i) { if (tensor_split[i] != 0.0f) { @@ -2652,6 +2742,7 @@ inline void ggml_cuda_op_mul( (void) dst; (void) src0_ddq_i; (void) i02; + (void) i1; } inline void ggml_cuda_op_gelu( @@ -2779,8 +2870,8 @@ inline void ggml_cuda_op_mul_mat_vec( #endif if (use_mul_mat_vec_q) { - int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1; - padded_row_size -= padded_row_size % MATRIX_ROW_PADDING; + const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ? + ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; size_t as; void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as); quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main); @@ -2947,13 +3038,18 @@ inline void ggml_cuda_op_rope( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + // RoPE alteration for extended context - const float theta_scale = powf(10000.0, -2.0f/n_dims); - const float p = ((mode & 1) == 0 ? n_past + i02 : i02); + float freq_base, freq_scale; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale; bool is_glm = mode & 4; @@ -2966,6 +3062,7 @@ inline void ggml_cuda_op_rope( rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); } + (void) src1; (void) dst; (void) src0_ddq_i; (void) src1_ddf_i; @@ -2984,11 +3081,12 @@ inline void ggml_cuda_op_diag_mask_inf( const int64_t ne01 = src0->ne[1]; const int64_t i01_diff = i01_high - i01_low; - const int n_past = ((int32_t *) src1->data)[0]; + const int n_past = ((int32_t *) dst->op_params)[0]; // compute diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main); + (void) src1; (void) dst; (void) src0_ddq_i; (void) src1_ddf_i; @@ -3056,6 +3154,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t ne11 = use_src1 ? src1->ne[1] : 1; const int64_t ne12 = use_src1 ? src1->ne[2] : 1; const int64_t ne13 = use_src1 ? src1->ne[3] : 1; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(ne03 == ne13); const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; @@ -3067,12 +3168,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); // strides for iteration over dims 3 and 2 - const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03; - const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1; + const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13; + const int64_t num_iters = flatten_rows ? 1 : num_iters_0; + const int64_t stride_mod = flatten_rows ? num_iters_0 : 1; const int64_t src0_stride = ne00 * ne01 * stride_mod; const int64_t src1_stride = ne10 * ne11 * stride_mod; const int64_t dst_stride = ne0 * ne1 * stride_mod; + const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; + const int64_t i03_max = flatten_rows ? 1 : ne03; + const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12); + const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02; + GGML_ASSERT(!(flatten_rows && ne02 < ne12)); + const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); @@ -3089,6 +3197,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE); const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 < ne12)); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); @@ -3125,7 +3234,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1]; } else { row_low = 0; - row_high = nrows0; + row_high = nrows0*i02_divisor; } if (row_low == row_high) { continue; @@ -3173,16 +3282,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]); } - const int64_t i03_max = flatten_rows ? 1 : ne03; - const int64_t i02_max = flatten_rows ? 1 : ne02; - const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; - for (int64_t i03 = 0; i03 < i03_max; i03++) { const int64_t i13 = i03 % ne13; for (int64_t i02 = 0; i02 < i02_max; i02++) { const int64_t i12 = i02 % ne12; - const int64_t i0 = i03*ne02 + i02; + const int64_t i0 = i03*i02_max + i02; // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs const int64_t i0_offset_low = row_low/rows_per_iter; @@ -3216,10 +3321,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t i11 = i13*ne12 + i12; // for split tensors the data begins at i0 == i0_offset_low - char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; - float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; + char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs; + float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride; float * src1_ddf_i = src1_ddf[id] + i11*src1_stride; - float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; + float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; // for split tensors the data pointer needs to be rounded down // to the bin edge for i03, i02 bins beyond the first @@ -3258,11 +3363,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } } - if (!src0_on_device || !src0_is_contiguous) { + if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { if (src0_is_f32) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } else { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } } @@ -3416,6 +3521,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; @@ -3428,7 +3535,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main); + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main); } void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ @@ -3442,6 +3549,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; @@ -3460,7 +3569,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int row_stride_x = nb01 / sizeof(half); const int channel_stride_x = nb02 / sizeof(half); - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main); + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -3601,7 +3710,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { size_t size = ggml_nbytes_split(tensor, nrows_split); const size_t original_size = size; - // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses if (ne0 % MATRIX_ROW_PADDING != 0) { size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); @@ -3617,7 +3726,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) { } - CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(buf, buf_host, original_size, cudaMemcpyHostToDevice)); extra->data_device[id] = buf; @@ -3697,7 +3806,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo char * src0_ddc = (char *) src0_extra->data_device[g_main_device]; size_t offset = 0; if (tensor->op == GGML_OP_VIEW) { - memcpy(&offset, tensor->src[2]->data, sizeof(size_t)); + memcpy(&offset, tensor->op_params, sizeof(size_t)); } extra = ggml_cuda_alloc_temp_tensor_extra(); extra->data_device[g_main_device] = src0_ddc + offset; diff --git a/src/ggml-metal.m b/src/ggml-metal.m index ee205bcd..bf3f68fe 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -42,6 +42,7 @@ struct ggml_metal_context { id pipeline_##name GGML_METAL_DECL_KERNEL(add); + GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast GGML_METAL_DECL_KERNEL(scale); @@ -157,6 +158,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { fprintf(stderr, "%s: loaded %-32s %16p\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name); GGML_METAL_ADD_KERNEL(add); + GGML_METAL_ADD_KERNEL(add_row); GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul_row); GGML_METAL_ADD_KERNEL(scale); @@ -464,10 +466,16 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - [encoder setComputePipelineState:ctx->pipeline_add]; + if (ggml_nelements(src1) == ne10) { + // src1 is a row + [encoder setComputePipelineState:ctx->pipeline_add_row]; + } else { + [encoder setComputePipelineState:ctx->pipeline_add]; + } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -577,7 +585,7 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - const int n_past = ((int32_t *)(src1->data))[0]; + const int n_past = ((int32_t *)(dst->op_params))[0]; [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -676,8 +684,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: @@ -685,8 +693,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: @@ -694,8 +702,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: @@ -703,8 +711,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; } break; case GGML_TYPE_Q6_K: @@ -712,8 +720,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; } break; default: @@ -739,16 +747,22 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_Q3_K || - src0t == GGML_TYPE_Q4_K || - src0t == GGML_TYPE_Q5_K || - src0t == GGML_TYPE_Q6_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + else if (src0t == GGML_TYPE_Q3_K) { +#ifdef GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -792,7 +806,7 @@ void ggml_metal_graph_compute( const float eps = 1e-6f; - const int nth = 256; + const int nth = 512; [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -800,7 +814,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -836,9 +850,10 @@ void ggml_metal_graph_compute( GGML_ASSERT((src0t == GGML_TYPE_F32)); - const int n_past = ((int32_t *) src1->data)[0]; UNUSED(n_past); - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); if (__builtin_popcount(n_head) != 1) { GGML_ASSERT(false && "only power-of-two n_head implemented"); @@ -876,15 +891,14 @@ void ggml_metal_graph_compute( encoder = [command_buffer computeCommandEncoder]; } - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - - const int n_past = ((int32_t *)(src1->data))[0]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; float freq_base; float freq_scale; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -913,7 +927,9 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_DUP: case GGML_OP_CPY: + case GGML_OP_CONT: { if (encoder == nil) { encoder = [command_buffer computeCommandEncoder]; diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 9f9a4fbd..987376d5 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -67,6 +67,17 @@ kernel void kernel_add( dst[tpig] = src0[tpig] + src1[tpig]; } +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + device const float * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] + src1[tpig % ne00]; +} + kernel void kernel_mul( device const float * src0, device const float * src1, @@ -331,26 +342,33 @@ kernel void kernel_rms_norm( threadgroup float * sum [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + float4 sumf=0; + float all_sum=0; // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (tiisg == 0) { + sum[sgitg] = all_sum; } - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } } - - // broadcast if (tpitg == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} sum[0] /= ne00; } @@ -359,104 +377,102 @@ kernel void kernel_rms_norm( const float mean = sum[0]; const float scale = 1.0f/sqrt(mean + eps); - device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + device float4 * y = (device float4 *) (dst + tgpig*ne00); + device float * y_scalar = (device float *) y; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { y[i00] = x[i00] * scale; } + if (tpitg == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + } +} + +// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 1); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); +} + +// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float m = qb_curr->m; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 2); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 -kernel void kernel_mul_mat_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne10, - constant int64_t & ne0, - constant int64_t & ne01[[buffer(4)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +template +void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, + int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, + uint2 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; - device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; device const float * y = (device const float *) src1 + r1*ne10; - block_q4_0 qb_curr, qb_next; float4 y_curr[8]; // src1 vector cache float sumf[N_DST]={0.f}, all_sum; thread float * yl=(thread float *)y_curr; - // bootstrap - qb_curr = x[tiisg]; // each thread in a SIMD group deals with 1 block. for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumy *= (-8.f); for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float acc = sumy; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc; - qb_curr = qb_next; + sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); } } - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - + // from now loads two rows every time and 16 blocks per row + int ir = tiisg / (N_SIMDWIDTH / 2); + int ib = tiisg % (N_SIMDWIDTH / 2); + for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { + int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start float sumy = 0; for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); + y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; } - sumy *= (-8.f); - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - float d = qb_curr.d; - float acc = sumy; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); + for (int row = 0; row < N_DST; row+=2) { + if (nb_start + ib < nb) { + sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc; - } - qb_curr = qb_next; + } + } - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; } } } -kernel void kernel_mul_mat_q4_1_f32( +kernel void kernel_mul_mat_q4_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -467,80 +483,21 @@ kernel void kernel_mul_mat_q4_1_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nb = ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; - device const float * y = (device const float *) src1 + r1*ne10; - block_q4_1 qb_curr, qb_next; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; - - // bootstrap - qb_curr = x[tiisg]; - // each thread in a SIMD group deals with 1 block. - for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc + m * sumy; - qb_curr = qb_next; - } - } - - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc + m * sumy; - } - qb_curr = qb_next; + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); +} - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } +kernel void kernel_mul_mat_q4_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( @@ -1263,111 +1220,137 @@ kernel void kernel_mul_mat_q2_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - float sumf = 0; + const int step = sizeof(block_q2_K) * nb; #if QK_K == 256 - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int ip = il/2; // 0 or 1 - const int shift1 = 4*(il%2);// 0 or 4 - const int shift2 = shift1+2;// 2 or 6 - const int n = 8; - const int is = 4*il + (n*ir)/16; - - const int y_offset = 64*il + n*ir; - const int q_offset = 32*ip + n*ir; - - for (int i = tpitg.x; i < nb; i += tptg.x) { - - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * scales = x[i].scales + is; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } - uint8_t d1 = scales[0] & 0xF; - uint8_t d2 = scales[2] & 0xF; - uint8_t m1 = scales[0] >> 4; - uint8_t m2 = scales[2] >> 4; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; - device const float * y = yy + i*QK_K + y_offset; + for (int row = 0; row < N_DST; row++) { - float2 s = {0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); - s[1] += y[l+32] * ((q[l] >> shift2) & 3); - smin += y[l+ 0] * m1 + y[l+32] * m2; + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; } - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; - - sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; - + y4 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0...1 - uint32_t aux[2]; - thread const uint8_t * d = (thread const uint8_t *)aux; - thread const uint8_t * m = (thread const uint8_t *)aux + 4; + device const float * y4 = y + ix * QK_K + 8 * it; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int ib = ix; ib < nb; ib += 16) { - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i*QK_K + il; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; + } - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; - device const uint32_t * a = (device const uint32_t *)x[i].scales; - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = (a[0] >> 4) & 0x0f0f0f0f; + for (int row = 0; row < N_DST; row++) { - for (int l = 0; l < 4; ++l) { - sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) - + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) - + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) - + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; } + + y4 += 16 * QK_K; } #endif - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } } } +#if QK_K == 256 kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, @@ -1376,40 +1359,41 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; -#if QK_K == 256 + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb; + device const float * yy = (device const float *) src1 + r1*ne10; - const uint8_t m3 = 3; - const int8_t m4 = 4; + float yl[16]; const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; - const int tid = tpitg.y; // expecting 16 + const int tid = tiisg/2; + const int ix = tiisg%2; const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 const int ir = tid%2; const int n = 8; const int l0 = n*ir; - const uint8_t m = 1 << (4*ip + il); + const uint16_t m1 = 1 << (4*ip + il); + const uint16_t m2 = m1 << 8; const int shift = 2*il; + const uint16_t qm1 = 0x0003 << shift; + const uint16_t qm2 = 0x0300 << shift; + const int32_t v1 = 4 << shift; + const int32_t v2 = 1024 << shift; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + 2*(il/2); @@ -1418,226 +1402,315 @@ kernel void kernel_mul_mat_q3_K_f32( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - //float sumf = 0; - float sumf1 = 0, sumf2 = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { - - const float d_all = (float)(x[i].d); + const int step = sizeof(block_q3_K) * nb / 2; - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * h = x[i].hmask + l0; - device const float * y = yy + i * QK_K + y_offset; + device const float * y1 = yy + ix*QK_K + y_offset; - device const uint16_t * a = (device const uint16_t *)x[i].scales; - const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); + float sumf1[2] = {0.f}, sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 2) { - float s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; + yl[l+8] = y1[l+16]; } - float d = d_all * s; - sumf1 += d * scales[0]; - sumf2 += d; - //sumf += d_all * s * (scales[0] - 32); - s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); - } - d = d_all * s; - sumf1 += d * scales[1]; - sumf2 += d; - //sumf += d_all * s * (scales[1] - 32); + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; - } + for (int row = 0; row < 2; ++row) { - //sum[ith] = sumf; - sum[ith] = sumf1 - 32.f*sumf2; -#else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 + const float d_all = (float)dh[0]; + const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); - float sumf = 0; - - for (int i = tpitg.y; i < nb; i += tptg.y) { - - const float d_all = (float)(x[i].d); - - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].hmask + in; - device const float * y = yy + i * QK_K + il; + float s1 = 0, s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2]; + s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); + s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); + } + float d = d_all * (s1 + 1.f/256.f * s2); + sumf1[row] += d * scales[0]; + sumf2[row] += d; + + s1 = s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2+8]; + s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); + s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); + } + d = d_all * (s1 + 1.f/256.f * s2); + sumf1[row] += d * scales[1]; + sumf2[row] += d; - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + q += step; + h += step; + a += step; + dh += step; - for (int l = 0; l < 4; ++l) { - const uint8_t hm = h[l] >> im; - sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) - + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) - + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) - + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); } - } - - sum[ith] = sumf; - -#endif + y1 += 2 * QK_K; - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; } + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } + } } - -kernel void kernel_mul_mat_q4_K_f32( +#else +kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne1, uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + const int row = 2 * r0 + sgitg; - device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb; + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; device const float * yy = (device const float *) src1 + r1*ne10; + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 - float sumf = 0; + float2 sum = {0.f, 0.f}; + + for (int i = ix; i < nb; i += 8) { + + const float d_all = (float)(x[i].d); + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const uint16_t * s = (device const uint16_t *)(x[i].scales); + device const float * y = yy + i * QK_K + il; + + const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); + const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; + const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; + const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; + + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> im; + sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) + + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) + + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) + + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) + + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); + } + + } + const float sumf = sum[0] + sum[1] * 1.f/256.f; + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; + } + +} +#endif #if QK_K == 256 +kernel void kernel_mul_mat_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; + const int step = sizeof(block_q4_K) * nb / 2; - uchar2 sc1, sc2, sc3, sc4; + device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; - for (int i = tpitg.x; i < nb; i += tptg.x) { + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - device const uint8_t * q1 = (x + i)->qs + q_offset; - device const uint8_t * q2 = q1 + 64; - device const float * y1 = yy + i*QK_K + y_offset; - device const float * y2 = y1 + 128; + for (int ib = ix; ib < nb; ib += 4) { - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; - sc1 = as_type((uint16_t)(a[im+0] & kmask1)); - sc2 = as_type((uint16_t)(a[im+2] & kmask1)); - sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); - sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { + for (int row = 0; row < N_DST; row++) { - s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4); - s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4); - smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; } - sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } } +} #else - uint16_t aux16[2]; - thread const uint8_t * scales = (thread const uint8_t *)aux16; +kernel void kernel_mul_mat_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int il = 4*tpitg.x; + const int ix = tiisg/4; // 0...7 + const int it = tiisg%4; // 0...3 - for (int i = tpitg.y; i < nb; i += tptg.y) { + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[8]; + float yh[8]; + float sumf[N_DST]={0.f}, all_sum; - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i * QK_K + il; + const int step = sizeof(block_q4_K) * nb / 2; - const float d = (float)x[i].d[0]; - const float m = (float)x[i].d[1]; + device const float * y4 = y + ix * QK_K + 8 * it; - device const uint16_t * a = (device const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; + uint16_t sc16[4]; - for (int l = 0; l < 4; ++l) { - sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16]) - + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]); + for (int ib = ix; ib < nb; ib += 8) { + + float2 sumy = {0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i] = y4[i+ 0]; sumy[0] += yl[i]; + yh[i] = y4[i+32]; sumy[1] += yh[i]; } - } -#endif - sum[ith] = sumf; + device const uint16_t * sc = (device const uint16_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = x[ib].d; - // - // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & 0x000f; + sc16[1] = sc[0] & 0x0f00; + sc16[2] = sc[0] & 0x00f0; + sc16[3] = sc[0] & 0xf000; + + float2 acc1 = {0.f, 0.f}; + float2 acc2 = {0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); + acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); + acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); + acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + + (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - + dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); + + qs += step; + sc += step; + dh += step; + } + + y4 += 8 * QK_K; } - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } + } } +#endif kernel void kernel_mul_mat_q5_K_f32( device const void * src0, @@ -1646,39 +1719,39 @@ kernel void kernel_mul_mat_q5_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf[2]={0.f}; - float sumf = 0; + const int step = sizeof(block_q5_K) * nb; #if QK_K == 256 +# + float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; + const int tid = tiisg/4; + const int ix = tiisg%4; + const int im = tid/4; + const int ir = tid%4; + const int n = 8; - const int l0 = n*(2*ir + in); + const int l0 = n*ir; const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; @@ -1687,78 +1760,113 @@ kernel void kernel_mul_mat_q5_K_f32( const uint8_t hm3 = hm1 << 4; const uint8_t hm4 = hm2 << 4; - uchar2 sc1, sc2, sc3, sc4; + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - for (int i = tpitg.x; i < nb; i += tptg.x) { + device const float * y1 = yy + ix*QK_K + y_offset; - device const uint8_t * q1 = (x + i)->qs + q_offset; - device const uint8_t * q2 = q1 + 64; - device const uint8_t * qh = (x + i)->qh + l0; - device const float * y1 = yy + i*QK_K + y_offset; - device const float * y2 = y1 + 128; + for (int i = ix; i < nb; i += 4) { - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + im; - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; - sc1 = as_type((uint16_t)(a[im+0] & kmask1)); - sc2 = as_type((uint16_t)(a[im+2] & kmask1)); - sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); - sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { + for (int row = 0; row < 2; ++row) { + + device const uint8_t * q2 = q1 + 64; - s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0)); - s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0)); - s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0)); - s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0)); - smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); + acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); + acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); + acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; } - sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + + y1 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 + float yl[8], yh[8]; + + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 + const int ix = tiisg%8; + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 - for (int i = tpitg.y; i < nb; i += tptg.y) { + device const float * y = yy + ix*QK_K + il; - const float d = (float)x[i].d; + for (int i = ix; i < nb; i += 8) { + + for (int l = 0; l < 4; ++l) { + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; + } + + device const half * dh = &x[i].d; device const uint8_t * q = x[i].qs + il; device const uint8_t * h = x[i].qh + in; device const int8_t * s = x[i].scales; - device const float * y = yy + i*QK_K + il; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> im; - sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16)) - + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16)) - + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16)) - + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16)); + for (int row = 0; row < 2; ++row) { + + const float d = dh[0]; + + float2 acc = {0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); + acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) + + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); + } + sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); + + q += step; + h += step; + s += step; + dh += step/2; + } + + y += 8 * QK_K; } #endif - sum[ith] = sumf; - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } } } @@ -1770,10 +1878,9 @@ kernel void kernel_mul_mat_q6_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -1785,19 +1892,18 @@ kernel void kernel_mul_mat_q6_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const int row = 2 * r0 + sgitg; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; float sumf = 0; #if QK_K == 256 - // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! - const int iqs = 16 * tpitg.y; - const int ip = iqs / 128; // 0 or 1 - const int il = (iqs - 128*ip)/16; // 0...7 + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; const int n = 4; const int l0 = n*il; const int is = 8*ip + l0/16; @@ -1806,9 +1912,10 @@ kernel void kernel_mul_mat_q6_K_f32( const int q_offset_l = 64*ip + l0; const int q_offset_h = 32*ip + l0; - for (int i = tpitg.x; i < nb; i += tptg.x) { + for (int i = ix; i < nb; i += 2) { - device const uint8_t * ql = x[i].ql + q_offset_l; + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; @@ -1818,19 +1925,21 @@ kernel void kernel_mul_mat_q6_K_f32( float4 sums = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); } sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); } + #else - const int il = 4*tpitg.x; // 0, 4, 8, 12 + const int ix = tiisg/4; + const int il = 4*(tiisg%4); - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int i = ix; i < nb; i += 8) { device const float * y = yy + i * QK_K + il; device const uint8_t * ql = x[i].ql + il; device const uint8_t * qh = x[i].qh + il; @@ -1850,23 +1959,8 @@ kernel void kernel_mul_mat_q6_K_f32( #endif - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; - } - } diff --git a/src/ggml.c b/src/ggml.c index c56a3d0e..14658ae4 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -4590,6 +4590,7 @@ struct ggml_tensor * ggml_new_tensor_impl( /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, /*.op =*/ GGML_OP_NONE, + /*.op_params =*/ {0}, /*.is_param =*/ false, /*.grad =*/ NULL, /*.src =*/ { NULL }, @@ -4969,6 +4970,11 @@ struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * return tensor; } +static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { + assert(params_size <= GGML_MAX_OP_PARAMS); + memcpy(tensor->op_params, params, params_size); +} + struct ggml_tensor * ggml_view_tensor( struct ggml_context * ctx, const struct ggml_tensor * src) { @@ -5019,7 +5025,6 @@ struct ggml_tensor * ggml_dup_impl( result->op = GGML_OP_DUP; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5143,23 +5148,13 @@ struct ggml_tensor * ggml_acc_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5); - - ((int32_t *) c->data)[0] = nb1; - ((int32_t *) c->data)[1] = nb2; - ((int32_t *) c->data)[2] = nb3; - ((int32_t *) c->data)[3] = offset; - ((int32_t *) c->data)[4] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ACC; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -5332,7 +5327,6 @@ struct ggml_tensor * ggml_sqr_impl( result->op = GGML_OP_SQR; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5366,7 +5360,6 @@ struct ggml_tensor * ggml_sqrt_impl( result->op = GGML_OP_SQRT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5401,7 +5394,6 @@ struct ggml_tensor * ggml_log_impl( result->op = GGML_OP_LOG; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5434,7 +5426,6 @@ struct ggml_tensor * ggml_sum( result->op = GGML_OP_SUM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5461,7 +5452,6 @@ struct ggml_tensor * ggml_sum_rows( result->op = GGML_OP_SUM_ROWS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5484,7 +5474,6 @@ struct ggml_tensor * ggml_mean( result->op = GGML_OP_MEAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5508,7 +5497,6 @@ struct ggml_tensor * ggml_argmax( result->op = GGML_OP_ARGMAX; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5586,7 +5574,6 @@ struct ggml_tensor * ggml_abs_impl( result->op = GGML_OP_ABS; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5621,7 +5608,6 @@ struct ggml_tensor * ggml_sgn_impl( result->op = GGML_OP_SGN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5655,7 +5641,6 @@ struct ggml_tensor * ggml_neg_impl( result->op = GGML_OP_NEG; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5689,7 +5674,6 @@ struct ggml_tensor * ggml_step_impl( result->op = GGML_OP_STEP; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5723,7 +5707,6 @@ struct ggml_tensor * ggml_tanh_impl( result->op = GGML_OP_TANH; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5757,7 +5740,6 @@ struct ggml_tensor * ggml_elu_impl( result->op = GGML_OP_ELU; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5791,7 +5773,6 @@ struct ggml_tensor * ggml_relu_impl( result->op = GGML_OP_RELU; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5825,7 +5806,6 @@ struct ggml_tensor * ggml_gelu_impl( result->op = GGML_OP_GELU; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5859,7 +5839,6 @@ struct ggml_tensor * ggml_gelu_quick_impl( result->op = GGML_OP_GELU_QUICK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5893,7 +5872,6 @@ struct ggml_tensor * ggml_silu_impl( result->op = GGML_OP_SILU; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -5948,10 +5926,11 @@ struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + // TODO: maybe store epsilon here? + result->op = GGML_OP_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? return result; } @@ -5980,10 +5959,11 @@ struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + // TODO: maybe store epsilon here? + result->op = GGML_OP_RMS_NORM; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; // TODO: maybe store epsilon here? return result; } @@ -6136,23 +6116,13 @@ struct ggml_tensor * ggml_set_impl( // make a view of the destination struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 5); - - (( int32_t * ) c->data)[0] = nb1; - (( int32_t * ) c->data)[1] = nb2; - (( int32_t * ) c->data)[2] = nb3; - (( int32_t * ) c->data)[3] = offset; - (( int32_t * ) c->data)[4] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; + ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_SET; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -6277,7 +6247,6 @@ struct ggml_tensor * ggml_cont_impl( result->op = GGML_OP_CONT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6321,7 +6290,6 @@ struct ggml_tensor * ggml_reshape( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6346,7 +6314,6 @@ struct ggml_tensor * ggml_reshape_1d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6372,7 +6339,6 @@ struct ggml_tensor * ggml_reshape_2d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6399,7 +6365,6 @@ struct ggml_tensor * ggml_reshape_3d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6428,7 +6393,6 @@ struct ggml_tensor * ggml_reshape_4d( result->op = GGML_OP_RESHAPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6450,19 +6414,11 @@ struct ggml_tensor * ggml_view_1d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, &ne0, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6488,13 +6444,7 @@ struct ggml_tensor * ggml_view_2d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->nb[1] = nb1; result->nb[2] = result->nb[1]*ne1; @@ -6503,8 +6453,6 @@ struct ggml_tensor * ggml_view_2d( result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6532,13 +6480,7 @@ struct ggml_tensor * ggml_view_3d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->nb[1] = nb1; result->nb[2] = nb2; @@ -6547,8 +6489,6 @@ struct ggml_tensor * ggml_view_3d( result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6578,13 +6518,7 @@ struct ggml_tensor * ggml_view_4d( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, (char *) a->data + offset); ggml_format_name(result, "%s (view)", a->name); - ggml_scratch_save(ctx); - - struct ggml_tensor * offs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(offs, "offset"); - memcpy(offs->data, &offset, 2*sizeof(int32_t)); - - ggml_scratch_load(ctx); + ggml_set_op_params(result, &offset, sizeof(offset)); result->nb[1] = nb1; result->nb[2] = nb2; @@ -6593,8 +6527,6 @@ struct ggml_tensor * ggml_view_4d( result->op = GGML_OP_VIEW; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = offs; return result; } @@ -6655,22 +6587,9 @@ struct ggml_tensor * ggml_permute( result->op = GGML_OP_PERMUTE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - - if (is_node) { - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); - ((int32_t *) b->data)[0] = axis0; - ((int32_t *) b->data)[1] = axis1; - ((int32_t *) b->data)[2] = axis2; - ((int32_t *) b->data)[3] = axis3; - - ggml_scratch_load(ctx); - - result->src[2] = b; - } + int32_t params[] = { axis0, axis1, axis2, axis3 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); return result; } @@ -6698,7 +6617,6 @@ struct ggml_tensor * ggml_transpose( result->op = GGML_OP_TRANSPOSE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6776,7 +6694,6 @@ struct ggml_tensor * ggml_diag( result->op = GGML_OP_DIAG; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6797,19 +6714,12 @@ struct ggml_tensor * ggml_diag_mask_inf_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { n_past, inplace ? 1 : 0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_DIAG_MASK_INF; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -6844,20 +6754,12 @@ struct ggml_tensor * ggml_diag_mask_zero_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 2); - ggml_set_name(b, "n_past, inplace"); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = inplace ? 1 : 0; - - ggml_scratch_load(ctx); + int32_t params[] = { n_past, inplace ? 1 : 0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_DIAG_MASK_ZERO; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -6893,7 +6795,6 @@ struct ggml_tensor * ggml_soft_max_impl( result->op = GGML_OP_SOFT_MAX; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; return result; } @@ -6956,9 +6857,9 @@ struct ggml_tensor * ggml_rope_impl( int n_past, int n_dims, int mode, + int n_ctx, float freq_base, float freq_scale, - int n_ctx, bool inplace) { GGML_ASSERT(n_past >= 0); bool is_node = false; @@ -6969,23 +6870,14 @@ struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_dims; - ((int32_t *) b->data)[2] = mode; - ((int32_t *) b->data)[3] = n_ctx; - memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float)); - memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float)); - - ggml_scratch_load(ctx); + int32_t params[6] = { n_past, n_dims, mode, n_ctx }; + memcpy(params + 4, &freq_base, sizeof(float)); + memcpy(params + 5, &freq_scale, sizeof(float)); + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_ROPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -6997,7 +6889,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false); } struct ggml_tensor * ggml_rope_inplace( @@ -7007,7 +6899,7 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true); } struct ggml_tensor * ggml_rope_custom_inplace( @@ -7016,10 +6908,10 @@ struct ggml_tensor * ggml_rope_custom_inplace( int n_past, int n_dims, int mode, + int n_ctx, float freq_base, - float freq_scale, - int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true); + float freq_scale) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true); } // ggml_rope_back @@ -7029,7 +6921,8 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * a, int n_past, int n_dims, - int mode) { + int mode, + int n_ctx) { GGML_ASSERT(n_past >= 0); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -7041,21 +6934,12 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - ggml_set_name(b, "n_past, n_dims, mode"); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_dims; - ((int32_t *) b->data)[2] = mode; - - ggml_scratch_load(ctx); + int32_t params[] = { n_past, n_dims, mode, n_ctx }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_ROPE_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7080,21 +6964,13 @@ struct ggml_tensor * ggml_alibi( //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = ggml_view_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - - ((int32_t *) b->data)[0] = n_past; - ((int32_t *) b->data)[1] = n_head; - GGML_ASSERT(sizeof(float) == sizeof(int32_t)); - (((float *) b->data)[2]) = bias_max; - - ggml_scratch_load(ctx); + int32_t op_params[3] = { n_past, n_head }; + memcpy(op_params + 2, &bias_max, sizeof(float)); + ggml_set_op_params(result, &op_params, sizeof(op_params)); result->op = GGML_OP_ALIBI; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7116,19 +6992,12 @@ struct ggml_tensor * ggml_clamp( // TODO: when implement backward, fix this: struct ggml_tensor * result = ggml_view_tensor(ctx, a); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2); - - ((float *) b->data)[0] = min; - ((float *) b->data)[1] = max; - - ggml_scratch_load(ctx); + float params[] = { min, max }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_CLAMP; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -7161,18 +7030,13 @@ GGML_API struct ggml_tensor * ggml_conv_1d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - ((int32_t*)c->data)[0] = s0; - ((int32_t*)c->data)[1] = p0; - ((int32_t*)c->data)[2] = d0; - ggml_scratch_load(ctx); + int32_t params[] = { s0, p0, d0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_CONV_1D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -7205,21 +7069,13 @@ struct ggml_tensor* ggml_conv_2d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6); - ((int32_t*)c->data)[0] = s0; - ((int32_t*)c->data)[1] = s1; - ((int32_t*)c->data)[2] = p0; - ((int32_t*)c->data)[3] = p1; - ((int32_t*)c->data)[4] = d0; - ((int32_t*)c->data)[5] = d1; - ggml_scratch_load(ctx); + int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_CONV_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; @@ -7243,7 +7099,7 @@ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) { return (ins + 2 * p - ks) / s + 1; } -// ggml_pool_2d +// ggml_pool_1d struct ggml_tensor* ggml_pool_1d( struct ggml_context * ctx, @@ -7266,18 +7122,12 @@ struct ggml_tensor* ggml_pool_1d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); - ((int32_t*)c->data)[0] = op; - ((int32_t*)c->data)[1] = k0; - ((int32_t*)c->data)[2] = s0; - ((int32_t*)c->data)[3] = p0; - ggml_scratch_load(ctx); + int32_t params[] = { op, k0, s0, p0 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_POOL_1D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = c; return result; } @@ -7309,21 +7159,12 @@ struct ggml_tensor* ggml_pool_2d( }; struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); - ggml_scratch_save(ctx); - struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7); - ((int32_t*)c->data)[0] = op; - ((int32_t*)c->data)[1] = k0; - ((int32_t*)c->data)[2] = k1; - ((int32_t*)c->data)[3] = s0; - ((int32_t*)c->data)[4] = s1; - ((int32_t*)c->data)[5] = p0; - ((int32_t*)c->data)[6] = p1; - ggml_scratch_load(ctx); + int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_POOL_2D; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = c; return result; } @@ -7482,21 +7323,12 @@ struct ggml_tensor * ggml_win_part( struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); - - ((int32_t *) b->data)[0] = npx; - ((int32_t *) b->data)[1] = npy; - ((int32_t *) b->data)[2] = w; - - ggml_scratch_load(ctx); + int32_t params[] = { npx, npy, w }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_WIN_PART; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = b; return result; } @@ -7521,19 +7353,12 @@ struct ggml_tensor * ggml_win_unpart( const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); - ggml_scratch_save(ctx); - - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - - ((int32_t *) b->data)[0] = w; - - ggml_scratch_load(ctx); + int32_t params[] = { w }; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_WIN_UNPART; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = NULL; - result->src[2] = b; return result; } @@ -7551,19 +7376,13 @@ struct ggml_tensor * ggml_map_unary_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_scratch_save(ctx); + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; - - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_UNARY; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[2] = addr_tensor; return result; } @@ -7598,20 +7417,14 @@ struct ggml_tensor * ggml_map_binary_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_BINARY; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = addr_tensor; return result; } @@ -7645,19 +7458,13 @@ struct ggml_tensor * ggml_map_custom1_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_CUSTOM1; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[2] = addr_tensor; return result; } @@ -7690,20 +7497,14 @@ struct ggml_tensor * ggml_map_custom2_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_CUSTOM2; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = addr_tensor; return result; } @@ -7739,21 +7540,15 @@ struct ggml_tensor * ggml_map_custom3_impl_f32( is_node = true; } - struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - - ggml_scratch_save(ctx); - - struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t)); - *((void (**)(void))addr_tensor->data) = (void (*)(void))fun; + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_scratch_load(ctx); + ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); result->op = GGML_OP_MAP_CUSTOM3; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = addr_tensor; - result->src[3] = c; + result->src[2] = c; return result; } @@ -8981,21 +8776,17 @@ static void ggml_compute_forward_acc_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - GGML_ASSERT(opt0->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(opt0) == 5); - // view src0 and dst with these strides and data offset inbytes during acc // nb0 is implicitely element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) opt0->data)[0]; - size_t nb2 = ((int32_t *) opt0->data)[1]; - size_t nb3 = ((int32_t *) opt0->data)[2]; - size_t offset = ((int32_t *) opt0->data)[3]; - bool inplace = (bool) ((int32_t *) opt0->data)[4]; + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; if (!inplace && (params->type == GGML_TASK_INIT)) { // memcpy needs to be synchronized across threads to avoid race conditions. @@ -9064,13 +8855,12 @@ static void ggml_compute_forward_acc( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_acc_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_acc_f32(params, src0, src1, dst); } break; case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -11090,21 +10880,17 @@ static void ggml_compute_forward_set_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - GGML_ASSERT(opt0->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(opt0) == 5); - // view src0 and dst with these strides and data offset inbytes during set // nb0 is implicitely element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) opt0->data)[0]; - size_t nb2 = ((int32_t *) opt0->data)[1]; - size_t nb3 = ((int32_t *) opt0->data)[2]; - size_t offset = ((int32_t *) opt0->data)[3]; - bool inplace = (bool) ((int32_t *) opt0->data)[4]; + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; if (!inplace && (params->type == GGML_TASK_INIT)) { // memcpy needs to be synchronized across threads to avoid race conditions. @@ -11164,13 +10950,12 @@ static void ggml_compute_forward_set( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_set_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_set_f32(params, src0, src1, dst); } break; case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -11566,17 +11351,14 @@ static void ggml_compute_forward_diag( static void ggml_compute_forward_diag_mask_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst, const float value) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 2); const int ith = params->ith; const int nth = params->nth; - const int n_past = ((int32_t *) src1->data)[0]; - const bool inplace = (bool)((int32_t *) src1->data)[1]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const bool inplace = (bool)((int32_t *) dst->op_params)[1]; GGML_ASSERT(n_past >= 0); @@ -11619,12 +11401,11 @@ static void ggml_compute_forward_diag_mask_f32( static void ggml_compute_forward_diag_mask_inf( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, -INFINITY); + ggml_compute_forward_diag_mask_f32(params, src0, dst, -INFINITY); } break; default: { @@ -11636,12 +11417,11 @@ static void ggml_compute_forward_diag_mask_inf( static void ggml_compute_forward_diag_mask_zero( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_diag_mask_f32(params, src0, src1, dst, 0); + ggml_compute_forward_diag_mask_f32(params, src0, dst, 0); } break; default: { @@ -11839,20 +11619,17 @@ static void ggml_compute_forward_soft_max_back( static void ggml_compute_forward_alibi_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); assert(n_past >= 0); @@ -11905,20 +11682,17 @@ static void ggml_compute_forward_alibi_f32( static void ggml_compute_forward_alibi_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 3); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int n_past = ((int32_t *) src1->data)[0]; - const int n_head = ((int32_t *) src1->data)[1]; - const float max_bias = ((float *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_head = ((int32_t *) dst->op_params)[1]; + float max_bias; + memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); assert(n_past >= 0); @@ -11971,16 +11745,15 @@ static void ggml_compute_forward_alibi_f16( static void ggml_compute_forward_alibi( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_alibi_f16(params, src0, src1, dst); + ggml_compute_forward_alibi_f16(params, src0, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_alibi_f32(params, src0, src1, dst); + ggml_compute_forward_alibi_f32(params, src0, dst); } break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -12010,19 +11783,17 @@ static void ggml_compute_forward_alibi( static void ggml_compute_forward_clamp_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { assert(params->ith == 0); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_nelements(src1) == 2); - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const float min = ((float *) src1->data)[0]; - const float max = ((float *) src1->data)[1]; + float min; + float max; + memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -12052,12 +11823,11 @@ static void ggml_compute_forward_clamp_f32( static void ggml_compute_forward_clamp( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_clamp_f32(params, src0, src1, dst); + ggml_compute_forward_clamp_f32(params, src0, dst); } break; case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -12087,10 +11857,7 @@ static void ggml_compute_forward_clamp( static void ggml_compute_forward_rope_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 6); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12099,12 +11866,12 @@ static void ggml_compute_forward_rope_f32( float freq_base; float freq_scale; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); assert(n_past >= 0); @@ -12219,10 +11986,7 @@ static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 6); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12231,12 +11995,12 @@ static void ggml_compute_forward_rope_f16( float freq_base; float freq_scale; - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; - memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); assert(n_past >= 0); @@ -12351,16 +12115,15 @@ static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_f16(params, src0, src1, dst); + ggml_compute_forward_rope_f16(params, src0, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_f32(params, src0, src1, dst); + ggml_compute_forward_rope_f32(params, src0, dst); } break; default: { @@ -12374,10 +12137,7 @@ static void ggml_compute_forward_rope( static void ggml_compute_forward_rope_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12387,9 +12147,9 @@ static void ggml_compute_forward_rope_back_f32( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; assert(n_past >= 0); @@ -12473,10 +12233,7 @@ static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -12486,9 +12243,9 @@ static void ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; assert(n_past >= 0); @@ -12572,16 +12329,15 @@ static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_back_f16(params, src0, src1, dst); + ggml_compute_forward_rope_back_f16(params, src0, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_back_f32(params, src0, src1, dst); + ggml_compute_forward_rope_back_f32(params, src0, dst); } break; default: { @@ -12778,7 +12534,7 @@ static void ggml_compute_forward_conv_1d_s1_ph( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { @@ -12981,7 +12737,7 @@ static void ggml_compute_forward_conv_1d_s2_ph( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { @@ -13001,14 +12757,13 @@ static void ggml_compute_forward_conv_1d_s2_ph( // ggml_compute_forward_conv_1d static void ggml_compute_forward_conv_1d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst) { - const int32_t s0 = ((const int32_t*)(opt0->data))[0]; - const int32_t p0 = ((const int32_t*)(opt0->data))[1]; - const int32_t d0 = ((const int32_t*)(opt0->data))[2]; + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct ggml_tensor * dst) { + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; GGML_ASSERT(d0 == 1); // dilation not supported GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported if (s0 == 1) { @@ -13026,7 +12781,6 @@ static void ggml_compute_forward_conv_2d_f16_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -13046,12 +12800,12 @@ static void ggml_compute_forward_conv_2d_f16_f32( // size of the convolution row - the kernel size unrolled across all channels const int ew0 = nk0*nk1*ne02; - const int32_t s0 = ((const int32_t*)(opt0->data))[0]; - const int32_t s1 = ((const int32_t*)(opt0->data))[1]; - const int32_t p0 = ((const int32_t*)(opt0->data))[2]; - const int32_t p1 = ((const int32_t*)(opt0->data))[3]; - const int32_t d0 = ((const int32_t*)(opt0->data))[4]; - const int32_t d1 = ((const int32_t*)(opt0->data))[5]; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); @@ -13123,17 +12877,15 @@ static void ggml_compute_forward_conv_2d( const struct ggml_compute_params * params, const struct ggml_tensor * src0, const struct ggml_tensor * src1, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst - ) { + struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst); + ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst); + //ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); GGML_ASSERT(false); } break; default: @@ -13198,12 +12950,11 @@ static void ggml_compute_forward_pool_1d_sk_p0( // ggml_compute_forward_pool_1d static void ggml_compute_forward_pool_1d( - const struct ggml_compute_params* params, - const struct ggml_tensor* src0, - const struct ggml_tensor* opt0, - struct ggml_tensor* dst) { - GGML_ASSERT(opt0->ne[0] == 4); - const int* opts = (const int*)opt0->data; + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t* opts = (const int32_t*)dst->op_params; enum ggml_op_pool op = opts[0]; const int k0 = opts[1]; const int s0 = opts[2]; @@ -13217,12 +12968,12 @@ static void ggml_compute_forward_pool_1d( // ggml_compute_forward_pool_2d_sk_p0 static void ggml_compute_forward_pool_2d_sk_p0( - const struct ggml_compute_params * params, - const enum ggml_op_pool op, - const struct ggml_tensor * src, - const int k0, - const int k1, - struct ggml_tensor * dst) { + const struct ggml_compute_params * params, + const enum ggml_op_pool op, + const struct ggml_tensor * src, + const int k0, + const int k1, + struct ggml_tensor * dst) { assert(src->type == GGML_TYPE_F32); assert(params->ith == 0); @@ -13282,12 +13033,11 @@ static void ggml_compute_forward_pool_2d_sk_p0( // ggml_compute_forward_pool_2d static void ggml_compute_forward_pool_2d( - const struct ggml_compute_params * params, - const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, - struct ggml_tensor * dst) { - GGML_ASSERT(opt0->ne[0] == 7); - const int* opts = (const int*)opt0->data; + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; enum ggml_op_pool op = opts[0]; const int k0 = opts[1]; const int k1 = opts[2]; @@ -13312,7 +13062,7 @@ static void ggml_compute_forward_flash_attn_f32( const struct ggml_tensor * k, const struct ggml_tensor * v, const bool masked, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { int64_t t0 = ggml_perf_time_us(); UNUSED(t0); @@ -13490,7 +13240,7 @@ static void ggml_compute_forward_flash_attn_f16( const struct ggml_tensor * k, const struct ggml_tensor * v, const bool masked, - struct ggml_tensor * dst) { + struct ggml_tensor * dst) { int64_t t0 = ggml_perf_time_us(); UNUSED(t0); @@ -14255,7 +14005,6 @@ static void ggml_compute_forward_flash_attn_back( static void ggml_compute_forward_win_part_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -14264,9 +14013,9 @@ static void ggml_compute_forward_win_part_f32( GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - const int32_t nep0 = ((const int32_t *)(opt0->data))[0]; - const int32_t nep1 = ((const int32_t *)(opt0->data))[1]; - const int32_t w = ((const int32_t *)(opt0->data))[2]; + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t w = ((const int32_t *)(dst->op_params))[2]; assert(ne00 == ne0); assert(ne3 == nep0*nep1); @@ -14300,12 +14049,11 @@ static void ggml_compute_forward_win_part_f32( static void ggml_compute_forward_win_part( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_win_part_f32(params, src0, opt0, dst); + ggml_compute_forward_win_part_f32(params, src0, dst); } break; default: { @@ -14319,7 +14067,6 @@ static void ggml_compute_forward_win_part( static void ggml_compute_forward_win_unpart_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -14328,7 +14075,7 @@ static void ggml_compute_forward_win_unpart_f32( GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - const int32_t w = ((const int32_t *)(opt0->data))[0]; + const int32_t w = ((const int32_t *)(dst->op_params))[0]; // padding const int px = (w - ne1%w)%w; @@ -14362,12 +14109,11 @@ static void ggml_compute_forward_win_unpart_f32( static void ggml_compute_forward_win_unpart( const struct ggml_compute_params * params, const struct ggml_tensor * src0, - const struct ggml_tensor * opt0, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_win_unpart_f32(params, src0, opt0, dst); + ggml_compute_forward_win_unpart_f32(params, src0, dst); } break; default: { @@ -14886,7 +14632,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ACC: { - ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_SUB: { @@ -15006,7 +14752,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_SET: { - ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_CPY: { @@ -15046,11 +14792,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_DIAG_MASK_INF: { - ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor); } break; case GGML_OP_DIAG_MASK_ZERO: { - ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor); } break; case GGML_OP_SOFT_MAX: { @@ -15062,35 +14808,35 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ROPE: { - ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_rope(params, tensor->src[0], tensor); } break; case GGML_OP_ROPE_BACK: { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_rope_back(params, tensor->src[0], tensor); } break; case GGML_OP_ALIBI: { - ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_alibi(params, tensor->src[0], tensor); } break; case GGML_OP_CLAMP: { - ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_clamp(params, tensor->src[0], tensor); } break; case GGML_OP_CONV_1D: { - ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_CONV_2D: { - ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_POOL_1D: { - ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_pool_1d(params, tensor->src[0], tensor); } break; case GGML_OP_POOL_2D: { - ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor); + ggml_compute_forward_pool_2d(params, tensor->src[0], tensor); } break; case GGML_OP_FLASH_ATTN: { @@ -15112,40 +14858,45 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_WIN_PART: { - ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor); + ggml_compute_forward_win_part(params, tensor->src[0], tensor); } break; case GGML_OP_WIN_UNPART: { - ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor); + ggml_compute_forward_win_unpart(params, tensor->src[0], tensor); } break; case GGML_OP_MAP_UNARY: { - const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data); + ggml_unary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun); } break; case GGML_OP_MAP_BINARY: { - const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data); + ggml_binary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun); } break; case GGML_OP_MAP_CUSTOM1: { - const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data); + ggml_custom1_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun); } break; case GGML_OP_MAP_CUSTOM2: { - const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data); + ggml_custom2_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun); } break; case GGML_OP_MAP_CUSTOM3: { - const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data); - ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun); + ggml_custom3_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor, fun); } break; case GGML_OP_CROSS_ENTROPY_LOSS: @@ -15209,12 +14960,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); } if (src1->grad) { - GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5); - GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32); - const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0]; - const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1]; - const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2]; - const size_t offset = (( int32_t * ) tensor->src[2]->data)[3]; + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, tensor->grad, @@ -15522,12 +15271,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_SET: { - GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5); - GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32); - const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0]; - const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1]; - const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2]; - const size_t offset = (( int32_t * ) tensor->src[2]->data)[3]; + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; struct ggml_tensor * tensor_grad_view = NULL; @@ -15604,8 +15351,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { size_t offset; - GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2])); - memcpy(&offset, tensor->src[2]->data, sizeof(offset)); + memcpy(&offset, tensor->op_params, sizeof(offset)); size_t nb1 = tensor->nb[1]; size_t nb2 = tensor->nb[2]; @@ -15632,7 +15378,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - int32_t * axes = (int32_t *) tensor->src[2]->data; + int32_t * axes = (int32_t *) tensor->op_params; int axis0 = axes[0] & 0x3; int axis1 = axes[1] & 0x3; int axis2 = axes[2] & 0x3; @@ -15688,33 +15434,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); - const int n_past = ((int32_t *) src1->data)[0]; + const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_DIAG_MASK_ZERO: { // necessary for llama if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 2); - const int n_past = ((int32_t *) src1->data)[0]; + const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_SOFT_MAX: { @@ -15735,33 +15471,28 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 6); - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; + const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope_back(ctx, tensor->grad, n_past, n_dims, - mode), + mode, + n_ctx), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_ROPE_BACK: { if (src0->grad) { - assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 3); - const int n_past = ((int32_t *) src1->data)[0]; - const int n_dims = ((int32_t *) src1->data)[1]; - const int mode = ((int32_t *) src1->data)[2]; - const int n_ctx = ((int32_t *) src1->data)[3]; + const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + const int n_ctx = ((int32_t *) tensor->op_params)[3]; src0->grad = ggml_add_impl(ctx, src0->grad, ggml_rope(ctx, @@ -15772,9 +15503,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor n_ctx), inplace); } - if (src1->grad) { - // noop - } } break; case GGML_OP_ALIBI: { @@ -16538,10 +16266,10 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS_BACK: case GGML_OP_DIAG: - case GGML_OP_DIAG_MASK_ZERO: { n_tasks = 1; } break; + case GGML_OP_DIAG_MASK_ZERO: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -16988,7 +16716,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { fwrite(&nb, sizeof(uint64_t), 1, fout); } - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); // dump the data // TODO: pad this to 32 byte boundary @@ -17021,7 +16750,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { fwrite(&nb, sizeof(uint64_t), 1, fout); } - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); + fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); // output the op arguments { @@ -17202,7 +16932,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** tensor->op = (enum ggml_op) op; - memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; + memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; tensor->data = (void *) ptr; @@ -17247,7 +16978,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** nb[j] = nb_cur; } - const char * ptr_name = ptr; ptr += GGML_MAX_NAME; + const char * ptr_name = ptr; ptr += GGML_MAX_NAME; + const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS; const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t); @@ -17284,8 +17016,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** { tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - uint64_t offs; - memcpy(&offs, args[2]->data, sizeof(offs)); + size_t offs; + memcpy(&offs, ptr_op_params, sizeof(offs)); tensor->data = ((char *) tensor->data) + offs; } break; @@ -17305,7 +17037,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** } break; } - memcpy(tensor->name, ptr_name, GGML_MAX_NAME); + memcpy(tensor->name, ptr_name, GGML_MAX_NAME); + memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS); for (int j = 0; j < GGML_MAX_DIMS; ++j) { tensor->nb[j] = nb[j];