]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : always check bounds on get_rows operations (#9354)
authorslaren <redacted>
Sat, 7 Sep 2024 18:23:07 +0000 (20:23 +0200)
committerGitHub <redacted>
Sat, 7 Sep 2024 18:23:07 +0000 (20:23 +0200)
ggml/src/ggml.c

index 9dc12a02079211613250b1f72851ec8ea5d18182..c25f9caa38500eb16d13c8ad0644cfe9683c7a73 100644 (file)
@@ -13721,7 +13721,7 @@ static void ggml_compute_forward_get_rows_q(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         dequantize_row_q(
                 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13762,7 +13762,7 @@ static void ggml_compute_forward_get_rows_f16(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_fp16_to_fp32_row(
                 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13803,7 +13803,7 @@ static void ggml_compute_forward_get_rows_bf16(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_bf16_to_fp32_row(
                 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
@@ -13844,7 +13844,7 @@ static void ggml_compute_forward_get_rows_f32(
         const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
         const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
-        assert(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_vec_cpy_f32(nc,
                 (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),