]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : sync (abort callback, mul / add broadcast, fix alibi) (#2183)
authorGeorgi Gerganov <redacted>
Tue, 11 Jul 2023 19:53:34 +0000 (22:53 +0300)
committerGitHub <redacted>
Tue, 11 Jul 2023 19:53:34 +0000 (22:53 +0300)
ggml-cuda.cu
ggml.c
ggml.h
tests/test-grad0.c
tests/test-opt.c

index 1673e7e4c9ef4263b11f89499925dbeed305b643..2fb30c6e6087ce560c56cc3b262eb36bb7ff2df2 100644 (file)
@@ -239,13 +239,13 @@ struct ggml_tensor_extra_gpu {
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
 };
 
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i >= kx) {
         return;
     }
-    dst[i] = x[i] + y[i];
+    dst[i] = x[i] + y[i%ky];
 }
 
 static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const float eps = 1e-5f;
+
+    float mean = 0.0f;
+    float var = 0.0f;
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        const float xi = x[row*ncols + col];
+        mean += xi;
+        var += xi * xi;
+    }
+
+    // sum up partial sums
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    }
+
+    mean /= ncols;
+    var = var / ncols - mean * mean;
+    const float inv_var = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    }
+}
+
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6;
+    const float eps = 1e-6f;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 
     const float mean = tmp / ncols;
-    const float scale = 1.0f / sqrtf(mean + eps);
+    const float scale = rsqrtf(mean + eps);
 
-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
@@ -1689,9 +1718,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
 static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -2239,14 +2274,16 @@ inline void ggml_cuda_op_add(
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
-    const int64_t ne0 = src0->ne[0];
+    const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    const int64_t ne10 = src1->ne[0];
+
     // compute
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
     } else {
         GGML_ASSERT(false);
     }
@@ -2268,20 +2305,11 @@ inline void ggml_cuda_op_mul(
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
 
     const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
 
-        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-
-        // compute
-        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-    }
+    mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10, cudaStream_main);
 
     (void) dst;
     (void) src0_ddq_i;
@@ -2310,6 +2338,28 @@ inline void ggml_cuda_op_silu(
     (void) i1;
 }
 
+inline void ggml_cuda_op_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2930,6 +2980,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
 }
 
+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3215,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }
 
 
-        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
 
         extra->data_device[id] = buf;
 
@@ -3322,6 +3377,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_silu;
             break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_norm;
+            break;
         case GGML_OP_RMS_NORM:
             if (!any_on_device) {
                 return false;
diff --git a/ggml.c b/ggml.c
index 8dc30a372e1ae45a5efd80d84b51015100eba198..793ff709508a2b325237bb9276fecbf0d3349621 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -25,6 +25,7 @@
 #include <float.h>
 #include <limits.h>
 #include <stdarg.h>
+#include <signal.h>
 
 #ifdef GGML_USE_METAL
 #include <unistd.h>
 typedef volatile LONG atomic_int;
 typedef atomic_int atomic_bool;
 
-static void atomic_store(atomic_int* ptr, LONG val) {
+static void atomic_store(atomic_int * ptr, LONG val) {
     InterlockedExchange(ptr, val);
 }
-static LONG atomic_load(atomic_int* ptr) {
+static LONG atomic_load(atomic_int * ptr) {
     return InterlockedCompareExchange(ptr, 0, 0);
 }
-static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
     return InterlockedExchangeAdd(ptr, inc);
 }
-static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
     return atomic_fetch_add(ptr, -(dec));
 }
 
 typedef HANDLE pthread_t;
 
 typedef DWORD thread_ret_t;
-static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
     (void) unused;
     HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
     if (handle == NULL)
@@ -77,7 +78,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
     return 0;
 }
 
-static int pthread_join(pthread_t thread, void* unused) {
+static int pthread_join(pthread_t thread, void * unused) {
     (void) unused;
     return (int) WaitForSingleObject(thread, INFINITE);
 }
@@ -90,7 +91,7 @@ static int sched_yield (void) {
 #include <pthread.h>
 #include <stdatomic.h>
 
-typedef void* thread_ret_t;
+typedef void * thread_ret_t;
 
 #include <sys/types.h>
 #include <sys/stat.h>
@@ -4723,7 +4724,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
             {
                 assert(tensor->nb[0] == sizeof(ggml_fp16_t));
                 for (int i = 0; i < n; i++) {
-                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
                 }
             } break;
         case GGML_TYPE_F32:
@@ -4775,7 +4776,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
             {
                 assert(tensor->nb[0] == sizeof(ggml_fp16_t));
                 for (int i = 0; i < n; i++) {
-                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                    ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
                 }
             } break;
         case GGML_TYPE_F32:
@@ -5035,11 +5036,15 @@ struct ggml_tensor * ggml_add_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
     bool is_node = false;
 
-    if (a->grad || b->grad) {
+    if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -8297,7 +8302,7 @@ static void ggml_compute_forward_add_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -8322,23 +8327,23 @@ static void ggml_compute_forward_add_f32(
 
     if (nb10 == sizeof(float)) {
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vadd(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
 #else
-            ggml_vec_add_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
@@ -8346,15 +8351,20 @@ static void ggml_compute_forward_add_f32(
     } else {
         // src1 is not contiguous
         for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
             for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
             }
@@ -11717,7 +11727,7 @@ static void ggml_compute_forward_alibi_f32(
 
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    const int ne2 = src0->ne[2]; // n_head -> this is k
     //const int ne3 = src0->ne[3]; // 1 -> bsz
 
     const int n  = ggml_nrows(src0);
@@ -11728,8 +11738,9 @@ static void ggml_compute_forward_alibi_f32(
     const int nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
-    assert(nb0 == sizeof(float));
-    assert(ne1 + n_past == ne0); (void) n_past;
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(ne1 + n_past == ne0);
+    GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11753,7 +11764,7 @@ static void ggml_compute_forward_alibi_f32(
                     m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
                 }
 
-                pdst[0] = (i-ne0+1) * m_k + src[0];
+                pdst[0] = i * m_k + src[0];
 
             }
         }
@@ -11782,7 +11793,7 @@ static void ggml_compute_forward_alibi_f16(
 
     const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
     const int ne1 = src0->ne[1]; // seq_len_without_past
-    //const int ne2 = src0->ne[2]; // n_head -> this is k
+    const int ne2 = src0->ne[2]; // n_head -> this is k
     //const int ne3 = src0->ne[3]; // 1 -> bsz
 
     const int n  = ggml_nrows(src0);
@@ -11793,8 +11804,9 @@ static void ggml_compute_forward_alibi_f16(
     const int nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
-    assert(nb0 == sizeof(ggml_fp16_t));
-    assert(ne1 + n_past == ne0); (void) n_past;
+    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+    GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11819,7 +11831,7 @@ static void ggml_compute_forward_alibi_f16(
                 }
 
                 // we return F32
-                pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+                pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
             }
         }
     }
@@ -15944,6 +15956,9 @@ struct ggml_compute_state_shared {
     // synchronization primitives
     atomic_int n_active; // num active threads
     atomic_int node_n;   // active graph node
+
+    bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
+    void * abort_callback_data;
 };
 
 struct ggml_compute_state {
@@ -15975,6 +15990,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     int node_n = -1;
 
     while (true) {
+        if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+            state->shared->node_n += 1;
+            return (thread_ret_t) GGML_EXIT_ABORTED;
+        }
         if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
             // all other threads are finished and spinning
             // do finalize and init here so we don't have synchronize again
@@ -16028,6 +16047,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                 } else {
                     break;
                 }
+
+                if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+                    break;
+                }
             }
 
             atomic_store(&state->shared->n_active, n_threads);
@@ -16061,7 +16084,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         }
     }
 
-    return 0;
+    return GGML_EXIT_SUCCESS;
 }
 
 struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
@@ -16401,7 +16424,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
     return cplan;
 }
 
-void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
     {
         GGML_ASSERT(cplan);
         GGML_ASSERT(cplan->n_threads > 0);
@@ -16427,6 +16450,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
         /*.n_threads               =*/ n_threads,
         /*.n_active                =*/ n_threads,
         /*.node_n                  =*/ -1,
+        /*.abort_callback          =*/ NULL,
+        /*.abort_callback_data     =*/ NULL,
     };
     struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
 
@@ -16450,12 +16475,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
     const int64_t perf_start_time_us = ggml_perf_time_us();
 
     // this is a work thread too
-    ggml_graph_compute_thread(&workers[0]);
+    int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
 
     // don't leave affinity set on the main thread
     clear_numa_thread_affinity();
 
-    // join thread pool
+    // join or kill thread pool
     if (n_threads > 1) {
         for (int j = 1; j < n_threads; j++) {
             const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16479,6 +16504,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
                 (double) perf_time_us_cur     / 1000.0,
                 (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
     }
+
+    return compute_status;
 }
 
 void ggml_graph_reset(struct ggml_cgraph * cgraph) {
diff --git a/ggml.h b/ggml.h
index d7c9e0f0e06815058ed61ac0128c61fecde5eb8a..8fe05d3a595b7b9368711cb453d1888ef65faa20 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 #define GGML_MAX_NAME          48
 #define GGML_DEFAULT_N_THREADS 4
 
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
 #define GGML_UNUSED(x) (void)(x)
 
+
 #define GGML_ASSERT(x) \
     do { \
         if (!(x)) { \
@@ -442,6 +447,10 @@ extern "C" {
 
         // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
         int n_tasks[GGML_MAX_NODES];
+
+        // abort ggml_graph_compute when true
+        bool (*abort_callback)(void * data);
+        void * abort_callback_data;
     };
 
     // computation graph
@@ -1303,7 +1312,7 @@ extern "C" {
     // ggml_graph_plan() has to be called before ggml_graph_compute()
     // when plan.work_size > 0, caller must allocate memory for plan.work_data
     GGML_API struct ggml_cplan ggml_graph_plan   (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
-    GGML_API              void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+    GGML_API               int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
     GGML_API              void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
     // same as ggml_graph_compute() but the work data is allocated as a part of the context
index da4001ce5269fbac5c9fa720396954f996662d9b..01467bc184372e5c36c069c22a6791d4ec6f2163 100644 (file)
@@ -10,7 +10,9 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+#if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
 
 #define MAX_NARGS 3
 
index e928a7df7ee68e60a9af80c46deb785d0f2a7383..5531814c48c997e1d2e0d5e563ffa9235877fbf1 100644 (file)
@@ -7,7 +7,9 @@
 
 #define MAX_NARGS 2
 
+#if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
 
 //
 // logging