]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync latest ggml repo
authorGeorgi Gerganov <redacted>
Sat, 20 May 2023 15:56:30 +0000 (18:56 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 May 2023 15:56:30 +0000 (18:56 +0300)
- new Q4 and Q8 quantization
- updated CUDA

examples/common.cpp
ggml-cuda.cu
ggml-cuda.h
ggml.c
ggml.h

index a8461fb4d069b77574a925cb800b310fbea4d0ea..76da30d9d02b270a16eb1034510dae9c064907bf 100644 (file)
@@ -26,7 +26,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-n" || arg == "--n_predict") {
             params.n_predict = std::stoi(argv[++i]);
         } else if (arg == "--top_k") {
-            params.top_k = std::stoi(argv[++i]);
+            params.top_k = std::max(1, std::stoi(argv[++i]));
         } else if (arg == "--top_p") {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
@@ -259,6 +259,7 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
                 if (it != vocab.token_to_id.end()) {
                     tokens.push_back(it->second);
                     i = j;
+                    j = n;
                     break;
                 }
                 --j;
index eb9f0df5aee3d32998fee6a345404b40b9bc8a6f..35d2e457cbf8487fd967b6459925a85d1ab0282b 100644 (file)
@@ -42,19 +42,19 @@ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y,
 #define QK4_0 32
 #define QR4_0 2
 typedef struct {
-    float   d;              // delta
+    half    d;              // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 #define QR4_1 2
 typedef struct {
-    float   d;              // delta
-    float   m;              // min
+    half    d;              // delta
+    half    m;              // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 #define QR5_0 2
@@ -78,12 +78,23 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 #define QK8_0 32
 #define QR8_0 1
 typedef struct {
-    float   d;              // delta
+    half    d;              // delta
     int8_t  qs[QK8_0];      // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
-#define CUDA_DMMV_BLOCK_SIZE 32
+#define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
+
+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
 
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -170,104 +181,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
     v1 = __half2float(x[ib + 1]);
 }
 
-static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
-    static const int qk = QK4_0;
-
-    const block_q4_0 * x = (const block_q4_0 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-
-    for (int j = 0; j < qk/2; ++j) {
-        const int x0 = (x[i].qs[j] & 0xf) - 8;
-        const int x1 = (x[i].qs[j] >>  4) - 8;
-
-        y[i*qk + j + 0   ] = x0*d;
-        y[i*qk + j + qk/2] = x1*d;
-    }
-}
-
-static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
-    static const int qk = QK4_1;
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+    const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
 
-    const block_q4_1 * x = (const block_q4_1 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-    const float m = x[i].m;
-
-    for (int j = 0; j < qk/2; ++j) {
-        const int x0 = (x[i].qs[j] & 0xf);
-        const int x1 = (x[i].qs[j] >>  4);
-
-        y[i*qk + j + 0   ] = x0*d + m;
-        y[i*qk + j + qk/2] = x1*d + m;
+    if (i >= k) {
+        return;
     }
-}
-
-static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
-    static const int qk = QK5_0;
-
-    const block_q5_0 * x = (const block_q5_0 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-
-    uint32_t qh;
-    memcpy(&qh, x[i].qh, sizeof(qh));
 
-    for (int j = 0; j < qk/2; ++j) {
-        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-        const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
-        const int32_t x1 = ((x[i].qs[j] >>  4) | xh_1) - 16;
-
-        y[i*qk + j + 0   ] = x0*d;
-        y[i*qk + j + qk/2] = x1*d;
-    }
-}
-
-static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
-    static const int qk = QK5_1;
-
-    const block_q5_1 * x = (const block_q5_1 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
-    const float m = x[i].m;
-
-    uint32_t qh;
-    memcpy(&qh, x[i].qh, sizeof(qh));
-
-    for (int j = 0; j < qk/2; ++j) {
-        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;
-        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
-
-        const int x0 = (x[i].qs[j] & 0xf) | xh_0;
-        const int x1 = (x[i].qs[j] >>  4) | xh_1;
-
-        y[i*qk + j + 0   ] = x0*d + m;
-        y[i*qk + j + qk/2] = x1*d + m;
-    }
-}
-
-static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
-    static const int qk = QK8_0;
-
-    const block_q8_0 * x = (const block_q8_0 *) vx;
-
-    const int i = blockIdx.x;
-
-    const float d = x[i].d;
+    const int ib = i/qk; // block index
+    const int iqs = (i%qk)/qr; // quant index
+    const int iybs = i - i%qk; // y block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
 
-    for (int j = 0; j < qk; ++j) {
-        y[i*qk + j] = x[i].qs[j]*d;
-    }
+    // dequantize
+    float & v0 = y[iybs + iqs + 0];
+    float & v1 = y[iybs + iqs + y_offset];
+    dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
 template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -308,29 +238,34 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
-static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK4_0;
-    dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
+static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK4_1;
-    dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK5_0;
-    dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK5_1;
-    dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
-    const int nb = k / QK8_0;
-    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
+static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -363,17 +298,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
         <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
 }
 
-// TODO: optimize
-static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
-    const half * x = (const half *) vx;
-
-    const int i = blockIdx.x;
-
-    y[i] = __half2float(x[i]);
-}
-
-static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
-    convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
+static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
+    dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -555,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[2];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -812,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];
 
@@ -885,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
 
     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }
 
-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+    FILE * fp = fopen(fname, "rb");
+
+    const size_t size = ggml_nbytes(tensor);
+
+    void * buf;
+    CUDA_CHECK(cudaMalloc(&buf, size));
+    void * buf_host = malloc(size);
+
+#ifdef _WIN32
+    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+#else
+    int ret = fseek(fp, (long) offset, SEEK_SET);
+#endif
+    GGML_ASSERT(ret == 0); // same
+
+    size_t ret2 = fread(buf_host, size, 1, fp);
+    if (ret2 != 1) {
+        fprintf(stderr, "unexpectedly reached end of file");
+        exit(1);
+    }
+
+    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    tensor->data = buf;
+    free(buf_host);
+    fclose(fp);
+}
index 4e2c24283ccf4b29d180d3f3f0625183f128fb88..6a04dde6c37a9d3179747575ecb2dee455a29720 100644 (file)
@@ -6,6 +6,7 @@ extern "C" {
 
 void   ggml_init_cublas(void);
 
+void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 
 #ifdef  __cplusplus
 }
diff --git a/ggml.c b/ggml.c
index da3d914e4ef47837fb5455676e61e95a3aef2579..f4a5a8d91fc9bbe9e26703891a79e0e218dfa9db 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -512,7 +512,7 @@ static inline int hsum_i32_4(const __m128i a) {
     return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
 }
 
-#if __AVX2__ || __AVX512F__
+#if defined(__AVX2__) || defined(__AVX512F__)
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-    // Get absolute values of x vectors
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = _mm256_sign_epi8(y, x);
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
 #if __AVXVNNI__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 #endif
 }
 
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+    const __m256i zero = _mm256_setzero_si256();
+    const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+    return _mm256_cvtepi32_ps(summed_pairs);
+#else
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
     return _mm256_cvtepi32_ps(summed_pairs);
 }
 
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+    const __m128i axl = _mm256_castsi256_si128(ax);
+    const __m128i axh = _mm256_extractf128_si256(ax, 1);
+    const __m128i syl = _mm256_castsi256_si128(sy);
+    const __m128i syh = _mm256_extractf128_si256(sy, 1);
+    // Perform multiplication and create 16-bit values
+    const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+    const __m128i doth = _mm_maddubs_epi16(axh, syh);
+    return sum_i16_pairs_float(doth, dotl);
+}
+
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
     const __m128i xl = _mm256_castsi256_si128(x);
@@ -667,7 +688,7 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 #endif // __AVX__ || __AVX2__ || __AVX512F__
 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
 
-#if __ARM_NEON
+#if defined(__ARM_NEON)
 
 #if !defined(__aarch64__)
 
@@ -748,18 +769,18 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
 
 #define QK4_0 32
 typedef struct {
-    float   d;          // delta
+    ggml_fp16_t d;          // delta
     uint8_t qs[QK4_0 / 2];  // nibbles / quants
 } block_q4_0;
-static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
 
 #define QK4_1 32
 typedef struct {
-    float   d;          // delta
-    float   m;          // min
+    ggml_fp16_t d;          // delta
+    ggml_fp16_t m;          // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
-static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding");
+static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
 #define QK5_0 32
 typedef struct {
@@ -780,16 +801,16 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
 
 #define QK8_0 32
 typedef struct {
-    float   d;          // delta
-    int8_t  qs[QK8_0];  // quants
+    ggml_fp16_t d;         // delta
+    int8_t  qs[QK8_0];     // quants
 } block_q8_0;
-static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
+static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
 #define QK8_1 32
 typedef struct {
-    float   d;         // delta
-    float   s;         // d * sum(qs[i])
-    int8_t  qs[QK8_1]; // quants
+    float d;               // delta
+    float s;               // d * sum(qs[i])
+    int8_t  qs[QK8_1];     // quants
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
 
@@ -816,7 +837,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
         const float d  = max / -8;
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = x[i*qk + 0    + j]*id;
@@ -856,8 +877,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
         const float d  = (max - min) / ((1 << 4) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
-        y[i].m = min;
+        y[i].d = GGML_FP32_TO_FP16(d);
+        y[i].m = GGML_FP32_TO_FP16(min);
 
         for (int j = 0; j < qk/2; ++j) {
             const float x0 = (x[i*qk + 0    + j] - min)*id;
@@ -988,7 +1009,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < QK8_0; ++j) {
             const float x0 = x[i*QK8_0 + j]*id;
@@ -1023,7 +1044,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         const float d = amax / ((1 << 7) - 1);
         const float id = d ? 1.0f/d : 0.0f;
 
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
 
         for (int j = 0; j < 8; j++) {
             const float32x4_t v  = vmulq_n_f32(srcv[j], id);
@@ -1058,7 +1079,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 
         // Quantize these floats
         const float d = maxScalar / 127.f;
-        y[i].d = d;
+        y[i].d = GGML_FP32_TO_FP16(d);
         const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
         const __m256 mul = _mm256_set1_ps( id );
 
@@ -1157,7 +1178,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r
             sum += y[i].qs[QK8_1/2 + j];
         }
 
-        y[i].s = d * sum;
+        y[i].s = sum*d;
     }
 }
 
@@ -1309,7 +1330,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F) - 8;
@@ -1329,8 +1350,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
     const int nb = k / qk;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
-        const float m = x[i].m;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
 
         for (int j = 0; j < qk/2; ++j) {
             const int x0 = (x[i].qs[j] & 0x0F);
@@ -1405,7 +1426,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
     const block_q8_0 * restrict x = vx;
 
     for (int i = 0; i < nb; i++) {
-        const float d = x[i].d;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
 
         for (int j = 0; j < qk; ++j) {
             y[i*qk + j] = x[i].qs[j]*d;
@@ -1669,8 +1690,9 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
 static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
     float tmp[8];
 
-    for (int i = 0; i < 8; i++)
+    for (int i = 0; i < 8; i++) {
         tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
 
     return _mm256_loadu_ps(tmp);
 }
@@ -2090,8 +2112,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const block_q8_0 * restrict y0 = &y[i + 0];
         const block_q8_0 * restrict y1 = &y[i + 1];
 
-        const uint8x16_t m4b   = vdupq_n_u8(0x0F);
-        const int8x16_t  s8b   = vdupq_n_s8(0x8);
+        const uint8x16_t m4b = vdupq_n_u8(0x0F);
+        const int8x16_t  s8b = vdupq_n_s8(0x8);
 
         const uint8x16_t v0_0 = vld1q_u8(x0->qs);
         const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2119,8 +2141,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
@@ -2137,8 +2159,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -2150,7 +2172,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
 
@@ -2174,7 +2196,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         const __m128i lowMask = _mm_set1_epi8(0xF);
         const __m128i off = _mm_set1_epi8(8);
@@ -2216,7 +2238,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) );
 
         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
 
@@ -2234,7 +2256,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) );
 
         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
 
@@ -2267,7 +2289,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+        const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) );
 
         const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
 
@@ -2285,7 +2307,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+        const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) );
 
         const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
 
@@ -2333,7 +2355,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi;
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
     }
 
     *s = sumf;
@@ -2363,7 +2385,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const block_q8_1 * restrict y0 = &y[i + 0];
         const block_q8_1 * restrict y1 = &y[i + 1];
 
-        summs += x0->m * y0->s + x1->m * y1->s;
+        summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s;
 
         const uint8x16_t m4b = vdupq_n_u8(0x0F);
 
@@ -2387,8 +2409,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
         const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
@@ -2405,8 +2427,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #endif
     }
 
@@ -2419,13 +2441,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
 
     // Main loop
     for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
+        const float d0 = GGML_FP16_TO_FP32(x[i].d);
+        const float d1 = y[i].d;
 
-        summs += x[i].m * y[i].s;
+        summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
 
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 d0v = _mm256_set1_ps( d0 );
+        const __m256 d1v = _mm256_set1_ps( d1 );
 
         // Compute combined scales
         const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
@@ -2434,7 +2456,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
         const __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
 
-        const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+        const __m256 xy = mul_sum_us8_pairs_float(bx, by);
 
         // Accumulate d0*d1*x*y
 #if defined(__AVX2__)
@@ -2459,7 +2481,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
             sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s;
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
     }
 
     *s = sumf;
@@ -2535,16 +2557,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-        const float x1d = GGML_FP16_TO_FP32(x1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2561,8 +2580,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -2637,7 +2656,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; i++) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2661,7 +2680,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; i++) {
         /* Compute combined scale for the block */
-        const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
 
         __m256i bx = bytes_from_nibbles_32(x[i].qs);
         const __m256i bxhi = bytes_from_bits_32(x[i].qh);
@@ -2704,7 +2723,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
             sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]);
         }
 
-        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi;
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
     }
 
     *s = sumf;
@@ -2786,16 +2805,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const int8x16_t v1_1l = vld1q_s8(y1->qs);
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-        const float x1d = GGML_FP16_TO_FP32(x1->d);
-
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
-                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
-                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
@@ -2812,8 +2828,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d);
 #endif
     }
 
@@ -2873,15 +2889,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
         const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);
 
-        const float x0d = GGML_FP16_TO_FP32(x0->d);
-
         // dot product
-        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
-                        wasm_i32x4_add(
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
-                                           wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
-                            wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
-                                           wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
+        sumv = wasm_f32x4_add(sumv,
+                wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
+                               wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
+                               wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
+                    wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d));
     }
 
     *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
@@ -2903,10 +2918,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10));
         bx = _mm256_or_si256(bx, bxhi);
 
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
@@ -2937,10 +2952,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         bxh = _mm_or_si128(bxh, bxhih);
         bx = _mm256_set_m128i(bxh, bxl);
 
-        const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+        const __m256 dy = _mm256_set1_ps(y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+        const __m256 q = mul_sum_us8_pairs_float(bx, by);
 
         acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
     }
@@ -3007,11 +3022,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
 #if defined(__ARM_FEATURE_DOTPROD)
         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x0_0, y0_0),
-                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d);
+                        vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
 
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
                         vdotq_s32(vdupq_n_s32(0), x1_0, y1_0),
-                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d);
+                        vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 
 #else
         const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0));
@@ -3029,8 +3044,8 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
         const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
         const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d);
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
 #endif
     }
 
@@ -3042,7 +3057,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     // Main loop
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
         __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
         __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
@@ -3068,7 +3083,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
             sumi += x[i].qs[j]*y[i].qs[j];
         }
 
-        sumf += (x[i].d*y[i].d)*sumi;
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
     }
 
     *s = sumf;
@@ -3457,6 +3472,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "ROPE",
     "ROPE_BACK",
     "ALIBI",
+    "CLAMP",
     "CONV_1D_1S",
     "CONV_1D_2S",
 
@@ -3467,7 +3483,8 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
+
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3517,6 +3534,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "rope(x)",
     "rope_back(x)",
     "alibi(x)",
+    "clamp(x)",
     "conv_1d_1s(x)",
     "conv_1d_2s(x)",
 
@@ -3527,7 +3545,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 50, "GGML_OP_COUNT != 50");
+static_assert(GGML_OP_COUNT == 51, "GGML_OP_COUNT != 51");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3761,6 +3779,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
         (t1->ne[3]%t0->ne[3] == 0);
 }
 
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
 static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
@@ -4643,11 +4667,15 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -6189,7 +6217,8 @@ struct ggml_tensor * ggml_alibi(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
-        int                   n_head) {
+        int                   n_head,
+        float                 bias_max) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
@@ -6208,6 +6237,8 @@ struct ggml_tensor * ggml_alibi(
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_head;
+    GGML_ASSERT(sizeof(float) == sizeof(int32_t));
+    (((float *) b->data)[2]) = bias_max;
 
     ggml_scratch_load(ctx);
 
@@ -6219,6 +6250,40 @@ struct ggml_tensor * ggml_alibi(
     return result;
 }
 
+// ggml_clamp
+
+struct ggml_tensor * ggml_clamp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 min,
+        float                 max) {
+    bool is_node = false;
+
+    if (a->grad) {
+        GGML_ASSERT(false); // TODO: implement backward
+        is_node = true;
+    }
+
+    // TODO: when implement backward, fix this:
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    ggml_scratch_save(ctx);
+
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+
+    ((float *) b->data)[0] = min;
+    ((float *) b->data)[1] = max;
+
+    ggml_scratch_load(ctx);
+
+    result->op   = GGML_OP_CLAMP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_conv_1d_1s
 
 struct ggml_tensor * ggml_conv_1d_1s(
@@ -7945,7 +8010,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -7953,10 +8018,25 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
+#ifdef GGML_USE_CUBLAS
+    if (src1->backend == GGML_BACKEND_CUDA) {
+        if (ith == 0) {
+            ggml_cuda_mul(src0, src1, dst);
+        }
+        return;
+    }
+#endif
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
 
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
@@ -7975,44 +8055,51 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
 
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
 #ifdef GGML_USE_ACCELERATE
             UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
 #else
-            ggml_vec_mul_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
@@ -10501,34 +10588,29 @@ static void ggml_compute_forward_diag_mask_f32(
     assert(src1->type == GGML_TYPE_I32);
     assert(ggml_nelements(src1) == 2);
 
+    const int ith = params->ith;
+    const int nth = params->nth;
+
     const int  n_past  =       ((int32_t *) src1->data)[0];
     const bool inplace = (bool)((int32_t *) src1->data)[1];
 
-    if (params->type == GGML_TASK_INIT) {
-        // TODO: this hack is not good, need a better way to handle this
-        if (!inplace) {
-            // use the init task to copy src -> dst
-            struct ggml_compute_params params_cpy = *params;
-
-            params_cpy.ith  = 0;
-            params_cpy.nth  = 1;
-            params_cpy.type = GGML_TASK_COMPUTE;
-
-            ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
-        }
+    assert(n_past >= 0);
 
-        return;
+    if (!inplace && (params->type == GGML_TASK_INIT)) {
+        // memcpy needs to be synchronized across threads to avoid race conditions.
+        // => do it in INIT phase
+        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+        memcpy(
+            ((char *)  dst->data),
+            ((char *) src0->data),
+            ggml_nbytes(dst));
     }
 
-    if (params->type == GGML_TASK_FINALIZE) {
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    assert(n_past >= 0);
-
     // TODO: handle transposed/permuted matrices
 
     const int n  = ggml_nrows(src0);
@@ -10682,14 +10764,15 @@ static void ggml_compute_forward_alibi_f32(
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_head = ((int32_t *) src1->data)[1];
+    const int   n_past   = ((int32_t *) src1->data)[0];
+    const int   n_head   = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *)   src1->data)[2];
 
     assert(n_past >= 0);
 
@@ -10712,8 +10795,8 @@ static void ggml_compute_forward_alibi_f32(
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
 
-    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
     for (int i = 0; i < ne0; i++) {
         for (int j = 0; j < ne1; j++) {
@@ -10731,13 +10814,13 @@ static void ggml_compute_forward_alibi_f32(
                     m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
                 }
 
-                pdst[0] = i * m_k + src[0];
+                pdst[0] = (i-ne0+1) * m_k + src[0];
+
             }
         }
     }
 }
 
-
 static void ggml_compute_forward_alibi_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -10745,14 +10828,15 @@ static void ggml_compute_forward_alibi_f16(
         struct ggml_tensor * dst) {
     assert(params->ith == 0);
     assert(src1->type == GGML_TYPE_I32);
-    assert(ggml_nelements(src1) == 2);
+    assert(ggml_nelements(src1) == 3);
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
-    const int n_past = ((int32_t *) src1->data)[0];
-    const int n_head = ((int32_t *) src1->data)[1];
+    const int   n_past   = ((int32_t *) src1->data)[0];
+    const int   n_head   = ((int32_t *) src1->data)[1];
+    const float max_bias = ((float *)   src1->data)[2];
 
     assert(n_past >= 0);
 
@@ -10775,8 +10859,8 @@ static void ggml_compute_forward_alibi_f16(
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
 
-    const float m0 = powf(2.0f, -8.0f / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -4.0f / n_heads_log2_floor);
+    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
     for (int i = 0; i < ne0; i++) {
         for (int j = 0; j < ne1; j++) {
@@ -10795,7 +10879,7 @@ static void ggml_compute_forward_alibi_f16(
                 }
 
                 // we return F32
-                pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
+                pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
             }
         }
     }
@@ -10831,6 +10915,77 @@ static void ggml_compute_forward_alibi(
     }
 }
 
+
+// ggml_compute_forward_clamp
+
+static void ggml_compute_forward_clamp_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    assert(params->ith == 0);
+    assert(src1->type == GGML_TYPE_I32);
+    assert(ggml_nelements(src1) == 2);
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int min = ((float *) src1->data)[0];
+    const int max = ((float *) src1->data)[1];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    for (int j = ith; j < n; j += nth) {
+        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
+        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
+
+        for (int i = 0; i < nc; i++) {
+            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
+        }
+    }
+}
+
+static void ggml_compute_forward_clamp(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_clamp_f32(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q8_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_rope
 
 static void ggml_compute_forward_rope_f32(
@@ -12812,6 +12967,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_CLAMP:
+            {
+                ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_CONV_1D_1S:
             {
                 ggml_compute_forward_conv_1d_1s(params, tensor->src0, tensor->src1, tensor);
@@ -13119,6 +13278,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_CLAMP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SILU:
             {
                 // necessary for llama
@@ -13998,6 +14161,10 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     {
                         node->n_tasks = 1; //TODO
                     } break;
+                case GGML_OP_CLAMP:
+                    {
+                        node->n_tasks = 1; //TODO
+                    } break;
                 case GGML_OP_CONV_1D_1S:
                 case GGML_OP_CONV_1D_2S:
                     {
diff --git a/ggml.h b/ggml.h
index 255541d0257e3abea7461e6c8a84a4cb5f1fe583..51a616c501bb3eb46bf2b20c727b0c0a0b7a16dc 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_FILE_MAGIC   0x67676d6c // "ggml"
 #define GGML_FILE_VERSION 1
 
-#define GGML_QNT_VERSION        1    // bump this on quantization format changes
+#define GGML_QNT_VERSION        2    // bump this on quantization format changes
 #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
 
 #define GGML_MAX_DIMS          4
@@ -313,6 +313,7 @@ extern "C" {
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
         GGML_OP_ALIBI,
+        GGML_OP_CLAMP,
         GGML_OP_CONV_1D_1S,
         GGML_OP_CONV_1D_2S,
 
@@ -849,7 +850,7 @@ extern "C" {
             int                   n_past);
 
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * gml_diag_mask_zero_inplace(
+    GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past);
@@ -897,7 +898,16 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
-            int                   n_head);
+            int                   n_head,
+            float                 bias_max);
+
+    // clamp
+    // in-place, returns view(a)
+    struct ggml_tensor * ggml_clamp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 min,
+            float                 max);
 
     // padding = 1
     // TODO: we don't support extra parameters for now