]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CLBlast: Fix temporary buffer size for f16 conversion (wsize)
authorshibe2 <redacted>
Wed, 11 Oct 2023 17:30:06 +0000 (21:30 +0400)
committershibe2 <redacted>
Tue, 17 Oct 2023 17:02:30 +0000 (21:02 +0400)
Fix buffer overflow.
Reduce the size to fit just one 2D slice.
Assert sufficient size.

ggml-opencl.cpp

index 33d0691eb74ca7f59bb3ddf43e9a8f6702b5e91d..22fd0e3a77fce21dcb2c3907369746e1f88ab4dd 100644 (file)
@@ -1568,7 +1568,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_cl_pool_free(d_D, d_size);
 }
 
-static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
     GGML_ASSERT(fp16_support);
 
     const int64_t ne00 = src0->ne[0];
@@ -1598,6 +1598,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+    ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
     size_t x_size;
     size_t y_size;
     size_t d_size;
@@ -1634,7 +1638,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
 
             // convert src1 to fp16
             // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
             char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
             if (src1_cont_rows) {
                 if (src1_cont_cols) {
@@ -1897,8 +1900,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
 }
 
 size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+    if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+        return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
     }
     return 0;
 }