]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp - CUDA improvements + ggml minor fixes
authorGeorgi Gerganov <redacted>
Sat, 20 May 2023 12:59:34 +0000 (15:59 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 May 2023 12:59:34 +0000 (15:59 +0300)
src/ggml-cuda.cu
src/ggml-cuda.h
src/ggml.c

index eb9f0df5aee3d32998fee6a345404b40b9bc8a6f..35d2e457cbf8487fd967b6459925a85d1ab0282b 100644 (file)
@@ -42,19 +42,19 @@ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y,
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    float   d;              // delta
+    half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 #define QR4_1 2
 typedef struct {
-    float   d;              // delta
-    float   m;              // min
+    half    d;              // delta
+    half    m;              // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 #define QR5_0 2
@@ -78,12 +78,23 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 #define QK8_0 32
 #define QR8_0 1
 typedef struct {
-    float   d;              // delta
+    half    d;              // delta
     int8_t  qs[QK8_0];      // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
-#define CUDA_DMMV_BLOCK_SIZE 32
+#define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
 
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -170,104 +181,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v1 = __half2float(x[ib + 1]);
 }
 
-static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
-    static const int qk = QK4_0;
-
-    const block_q4_0 * x = (const block_q4_0 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-
-    for (int j = 0; j < qk/2; ++j) {
-        const int x0 = (x[i].qs[j] & 0xf) - 8;
-        const int x1 = (x[i].qs[j] >>  4) - 8;
-
-        y[i*qk + j + 0   ] = x0*d;
-        y[i*qk + j + qk/2] = x1*d;
-    }
-}
-
-static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
-    static const int qk = QK4_1;
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
-    const block_q4_1 * x = (const block_q4_1 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-    const float m = x[i].m;
-
-    for (int j = 0; j < qk/2; ++j) {
-        const int x0 = (x[i].qs[j] & 0xf);
-        const int x1 = (x[i].qs[j] >>  4);
-
-        y[i*qk + j + 0   ] = x0*d + m;
-        y[i*qk + j + qk/2] = x1*d + m;
+    if (i >= k) {
+        return;
     }
-}
-
-static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
-    static const int qk = QK5_0;
-
-    const block_q5_0 * x = (const block_q5_0 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-
-    uint32_t qh;
-    memcpy(&qh, x[i].qh, sizeof(qh));
 
-    for (int j = 0; j < qk/2; ++j) {
-        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-        const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
-        const int32_t x1 = ((x[i].qs[j] >>  4) | xh_1) - 16;
-
-        y[i*qk + j + 0   ] = x0*d;
-        y[i*qk + j + qk/2] = x1*d;
-    }
-}
-
-static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
-    static const int qk = QK5_1;
-
-    const block_q5_1 * x = (const block_q5_1 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-    const float m = x[i].m;
-
-    uint32_t qh;
-    memcpy(&qh, x[i].qh, sizeof(qh));
-
-    for (int j = 0; j < qk/2; ++j) {
-        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-        const int x0 = (x[i].qs[j] & 0xf) | xh_0;
-        const int x1 = (x[i].qs[j] >>  4) | xh_1;
-
-        y[i*qk + j + 0   ] = x0*d + m;
-        y[i*qk + j + qk/2] = x1*d + m;
-    }
-}
-
-static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
-    static const int qk = QK8_0;
-
-    const block_q8_0 * x = (const block_q8_0 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
 
-    for (int j = 0; j < qk; ++j) {
-        y[i*qk + j] = x[i].qs[j]*d;
-    }
+    // dequantize
+    float & v0 = y[iybs + iqs + 0];
+    float & v1 = y[iybs + iqs + y_offset];
+    dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
 template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -308,29 +238,34 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
-static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK4_0;
-    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
+static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK4_1;
-    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK5_0;
-    dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK5_1;
-    dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK8_0;
-    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -363,17 +298,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
 }
 
-// TODO: optimize
-static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
-    const half * x = (const half *) vx;
-
-    const int i = blockIdx.x;
-
-    y[i] = __half2float(x[i]);
-}
-
-static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
-    convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
+static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -555,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[2];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -812,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];
 
@@ -885,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
 
     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }
 
-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+    FILE * fp = fopen(fname, "rb");
+
+    const size_t size = ggml_nbytes(tensor);
+
+    void * buf;
+    CUDA_CHECK(cudaMalloc(&buf, size));
+    void * buf_host = malloc(size);
+
+#ifdef _WIN32
+    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+#else
+    int ret = fseek(fp, (long) offset, SEEK_SET);
+#endif
+    GGML_ASSERT(ret == 0); // same
+
+    size_t ret2 = fread(buf_host, size, 1, fp);
+    if (ret2 != 1) {
+        fprintf(stderr, "unexpectedly reached end of file");
+        exit(1);
+    }
+
+    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    tensor->data = buf;
+    free(buf_host);
+    fclose(fp);
+}
index 4e2c24283ccf4b29d180d3f3f0625183f128fb88..6a04dde6c37a9d3179747575ecb2dee455a29720 100644 (file)
@@ -6,6 +6,7 @@ extern "C" {
 
 void   ggml_init_cublas(void);
 
+void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 
 #ifdef  __cplusplus
 }
index 6b48852d8093ac419301a5173ba6d5962d4ee675..f4a5a8d91fc9bbe9e26703891a79e0e218dfa9db 100644 (file)
@@ -3779,6 +3779,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
         (t1->ne[3]%t0->ne[3] == 0);
 }
 
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
 static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
@@ -4661,11 +4667,15 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -6240,7 +6250,7 @@ struct ggml_tensor * ggml_alibi(
     return result;
 }
 
-// ggml_alibi
+// ggml_clamp
 
 struct ggml_tensor * ggml_clamp(
         struct ggml_context * ctx,
@@ -6257,10 +6267,15 @@ struct ggml_tensor * ggml_clamp(
     // TODO: when implement backward, fix this:
     struct ggml_tensor * result = ggml_view_tensor(ctx, a);
 
+    ggml_scratch_save(ctx);
+
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
     ((float *) b->data)[0] = min;
     ((float *) b->data)[1] = max;
 
+    ggml_scratch_load(ctx);
+
     result->op   = GGML_OP_CLAMP;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src0 = a;
@@ -7995,7 +8010,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -8003,10 +8018,25 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
+#ifdef GGML_USE_CUBLAS
+    if (src1->backend == GGML_BACKEND_CUDA) {
+        if (ith == 0) {
+            ggml_cuda_mul(src0, src1, dst);
+        }
+        return;
+    }
+#endif
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
@@ -8025,44 +8055,51 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
 
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
 #ifdef GGML_USE_ACCELERATE
             UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
 #else
-            ggml_vec_mul_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
@@ -10784,7 +10821,6 @@ static void ggml_compute_forward_alibi_f32(
     }
 }
 
-
 static void ggml_compute_forward_alibi_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10880,7 +10916,7 @@ static void ggml_compute_forward_alibi(
 }
 
 
-// ggml_compute_forward_alibi
+// ggml_compute_forward_clamp
 
 static void ggml_compute_forward_clamp_f32(
         const struct ggml_compute_params * params,
@@ -10898,7 +10934,6 @@ static void ggml_compute_forward_clamp_f32(
     const int min = ((float *) src1->data)[0];
     const int max = ((float *) src1->data)[1];
 
-
     const int ith = params->ith;
     const int nth = params->nth;
 
@@ -10915,16 +10950,15 @@ static void ggml_compute_forward_clamp_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     for (int j = ith; j < n; j += nth) {
-        float * dst_ptr  = (float *) ((char *) dst->data  + j*nb1);
+        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
         float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
-        for (int i = 0; i < nc; i++) {
 
+        for (int i = 0; i < nc; i++) {
             dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
         }
     }
 }
 
-
 static void ggml_compute_forward_clamp(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,