From: Georgi Gerganov Date: Sat, 20 May 2023 12:59:34 +0000 (+0300) Subject: ggml : sync llama.cpp - CUDA improvements + ggml minor fixes X-Git-Tag: upstream/0.0.1642~1462 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=1ec82cbac5544c071619b22e948d9ec7ccfb5340;p=pkg%2Fggml%2Fsources%2Fggml ggml : sync llama.cpp - CUDA improvements + ggml minor fixes --- diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index eb9f0df5..35d2e457 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -42,19 +42,19 @@ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, #define QK4_0 32 #define QR4_0 2 typedef struct { - float d; // delta + half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 #define QR4_1 2 typedef struct { - float d; // delta - float m; // min + half d; // delta + half m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 #define QR5_0 2 @@ -78,12 +78,23 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 #define QR8_0 1 typedef struct { - float d; // delta + half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); -#define CUDA_DMMV_BLOCK_SIZE 32 +#define CUDA_MUL_BLOCK_SIZE 256 +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec + +static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= kx) { + return; + } + dst[i] = x[i] * y[i%ky]; +} static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ const block_q4_0 * x = (const block_q4_0 *) vx; @@ -170,104 +181,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v1 = __half2float(x[ib + 1]); } -static __global__ void dequantize_block_q4_0(const void * vx, float * y) { - static const int qk = QK4_0; - - const block_q4_0 * x = (const block_q4_0 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0xf) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } -} - -static __global__ void dequantize_block_q4_1(const void * vx, float * y) { - static const int qk = QK4_1; +template +static __global__ void dequantize_block(const void * vx, float * y, const int k) { + const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; - const block_q4_1 * x = (const block_q4_1 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - const float m = x[i].m; - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0xf); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; + if (i >= k) { + return; } -} - -static __global__ void dequantize_block_q5_0(const void * vx, float * y) { - static const int qk = QK5_0; - - const block_q5_0 * x = (const block_q5_0 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } -} - -static __global__ void dequantize_block_q5_1(const void * vx, float * y) { - static const int qk = QK5_1; - - const block_q5_1 * x = (const block_q5_1 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; - const float m = x[i].m; - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int x0 = (x[i].qs[j] & 0xf) | xh_0; - const int x1 = (x[i].qs[j] >> 4) | xh_1; - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } -} - -static __global__ void dequantize_block_q8_0(const void * vx, float * y) { - static const int qk = QK8_0; - - const block_q8_0 * x = (const block_q8_0 *) vx; - - const int i = blockIdx.x; - - const float d = x[i].d; + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = x[i].qs[j]*d; - } + // dequantize + float & v0 = y[iybs + iqs + 0]; + float & v1 = y[iybs + iqs + y_offset]; + dequantize_kernel(vx, ib, iqs, v0, v1); } template @@ -308,29 +238,34 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } -static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_0; - dequantize_block_q4_0<<>>(vx, y); +static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; + mul_f32<<>>(x, y, dst, kx, ky); +} + +static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK4_1; - dequantize_block_q4_1<<>>(vx, y); +static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK5_0; - dequantize_block_q5_0<<>>(vx, y); +static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK5_1; - dequantize_block_q5_1<<>>(vx, y); +static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } -static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { - const int nb = k / QK8_0; - dequantize_block_q8_0<<>>(vx, y); +static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); } static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -363,17 +298,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f <<>>(vx, y, dst, ncols); } -// TODO: optimize -static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { - const half * x = (const half *) vx; - - const int i = blockIdx.x; - - y[i] = __half2float(x[i]); -} - -static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { - convert_fp16_to_fp32<<>>(x, y); +static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<32, 1, convert_f16><<>>(vx, y, k); } static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { @@ -555,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor } } +static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[2]; + const int64_t ne0 = ne00 * ne01 * ne02 * ne03; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + size_t x_size, d_size; + + float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0 + float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted. + float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const int i0 = i03*ne02 + i02; + float * c_X2 = d_X + i0*ne01*ne00; + float * c_D2 = d_D + i0*ne01*ne00; + + cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS]; + cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS]; + + // copy src0 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2)); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // wait for data + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + for (int64_t i01 = 0; i01 < ne01; i01++) { + const int64_t i13 = i03%ne13; + const int64_t i12 = i02%ne12; + const int64_t i11 = i01%ne11; + const int i1 = i13*ne12*ne11 + i12*ne11 + i11; + + float * c_X1 = c_X2 + i01*ne00; + float * c_Y = d_Y + i1*ne10; + float * c_D1 = c_D2 + i01*ne00; + + // compute + mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream); + CUDA_CHECK(cudaGetLastError()); + } + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream)); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_D, d_size); +} + static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -812,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor ggml_cuda_pool_free(d_Q, q_size); } +void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_mul_f32(src0, src1, dst); +} + bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { const int64_t ne10 = src1->ne[0]; @@ -885,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); size_t q_size; - char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); + char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); cudaStream_t cudaStream2 = g_cudaStreams2[0]; // copy tensor to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); - CUDA_CHECK(cudaDeviceSynchronize()); + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + int i = i3*ne2 + i2; + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2)); + } + } - tensor->data = d_Q; + tensor->data = dst; tensor->backend = GGML_BACKEND_CUDA; } + +void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) { + FILE * fp = fopen(fname, "rb"); + + const size_t size = ggml_nbytes(tensor); + + void * buf; + CUDA_CHECK(cudaMalloc(&buf, size)); + void * buf_host = malloc(size); + +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, SEEK_SET); +#else + int ret = fseek(fp, (long) offset, SEEK_SET); +#endif + GGML_ASSERT(ret == 0); // same + + size_t ret2 = fread(buf_host, size, 1, fp); + if (ret2 != 1) { + fprintf(stderr, "unexpectedly reached end of file"); + exit(1); + } + + cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + tensor->data = buf; + free(buf_host); + fclose(fp); +} diff --git a/src/ggml-cuda.h b/src/ggml-cuda.h index 4e2c2428..6a04dde6 100644 --- a/src/ggml-cuda.h +++ b/src/ggml-cuda.h @@ -6,6 +6,7 @@ extern "C" { void ggml_init_cublas(void); +void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); @@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); +void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset); #ifdef __cplusplus } diff --git a/src/ggml.c b/src/ggml.c index 6b48852d..f4a5a8d9 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -3779,6 +3779,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g (t1->ne[3]%t0->ne[3] == 0); } +static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); +} + static inline int ggml_up32(int n) { return (n + 31) & ~31; } @@ -4661,11 +4667,15 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -6240,7 +6250,7 @@ struct ggml_tensor * ggml_alibi( return result; } -// ggml_alibi +// ggml_clamp struct ggml_tensor * ggml_clamp( struct ggml_context * ctx, @@ -6257,10 +6267,15 @@ struct ggml_tensor * ggml_clamp( // TODO: when implement backward, fix this: struct ggml_tensor * result = ggml_view_tensor(ctx, a); + ggml_scratch_save(ctx); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); + ((float *) b->data)[0] = min; ((float *) b->data)[1] = max; + ggml_scratch_load(ctx); + result->op = GGML_OP_CLAMP; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src0 = a; @@ -7995,7 +8010,7 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -8003,10 +8018,25 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; - const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; +#ifdef GGML_USE_CUBLAS + if (src1->backend == GGML_BACKEND_CUDA) { + if (ith == 0) { + ggml_cuda_mul(src0, src1, dst); + } + return; + } +#endif + + const int64_t nr = ggml_nrows(src0); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; @@ -8025,44 +8055,51 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(ne00 == ne10); if (nb10 == sizeof(float)) { - for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); #ifdef GGML_USE_ACCELERATE UNUSED(ggml_vec_mul_f32); - vDSP_vmul( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); #else - ggml_vec_mul_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif // } // } } } else { // src1 is not contiguous - for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); } @@ -10784,7 +10821,6 @@ static void ggml_compute_forward_alibi_f32( } } - static void ggml_compute_forward_alibi_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, @@ -10880,7 +10916,7 @@ static void ggml_compute_forward_alibi( } -// ggml_compute_forward_alibi +// ggml_compute_forward_clamp static void ggml_compute_forward_clamp_f32( const struct ggml_compute_params * params, @@ -10898,7 +10934,6 @@ static void ggml_compute_forward_clamp_f32( const int min = ((float *) src1->data)[0]; const int max = ((float *) src1->data)[1]; - const int ith = params->ith; const int nth = params->nth; @@ -10915,16 +10950,15 @@ static void ggml_compute_forward_clamp_f32( GGML_ASSERT(nb00 == sizeof(float)); for (int j = ith; j < n; j += nth) { - float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); float * src0_ptr = (float *) ((char *) src0->data + j*nb01); - for (int i = 0; i < nc; i++) { + for (int i = 0; i < nc; i++) { dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); } } } - static void ggml_compute_forward_clamp( const struct ggml_compute_params * params, const struct ggml_tensor * src0,