]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : always check bounds on get_rows operations (llama/9354)
authorslaren <redacted>
Sat, 7 Sep 2024 18:23:07 +0000 (20:23 +0200)
committerGeorgi Gerganov <redacted>
Sun, 8 Sep 2024 11:43:07 +0000 (14:43 +0300)
src/ggml.c

index 212be9fe597ef8647ac35e07247353fb47b56f02..28ee46e042bbce644505a77114ad3440cef08ad4 100644 (file)
@@ -13709,7 +13709,7 @@ static void ggml_compute_forward_get_rows_q(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         dequantize_row_q(
                 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13750,7 +13750,7 @@ static void ggml_compute_forward_get_rows_f16(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_fp16_to_fp32_row(
                 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13791,7 +13791,7 @@ static void ggml_compute_forward_get_rows_bf16(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_bf16_to_fp32_row(
                 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13832,7 +13832,7 @@ static void ggml_compute_forward_get_rows_f32(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_vec_cpy_f32(nc,
                 (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),