]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : export symbols (#1155)
authorGeorgi Gerganov <redacted>
Mon, 24 Apr 2023 19:18:25 +0000 (22:18 +0300)
committerGitHub <redacted>
Mon, 24 Apr 2023 19:18:25 +0000 (22:18 +0300)
ggml.h

diff --git a/ggml.h b/ggml.h
index 460d4ffe03d85adb427cb916369186edb6ba3b54..27589078181821d8025440fef9457ee4171a303b 100644 (file)
--- a/ggml.h
+++ b/ggml.h
 //
 //
 
-#ifdef  __cplusplus
-extern "C" {
+#ifdef GGML_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BUILD
+#            define GGML_API __declspec(dllexport)
+#        else
+#            define GGML_API __declspec(dllimport)
+#        endif
+#    else
+#        define GGML_API __attribute__ ((visibility ("default")))
+#    endif
+#else
+#    define GGML_API
 #endif
 
 #include <stdint.h>
 #include <stddef.h>
 #include <stdbool.h>
 
+#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
+#define GGML_FILE_VERSION 1
+
 #define GGML_MAX_DIMS          4
 #define GGML_MAX_NODES         4096
 #define GGML_MAX_PARAMS        16
@@ -184,682 +197,688 @@ extern "C" {
 #define GGML_MAX_OPT           4
 #define GGML_DEFAULT_N_THREADS 4
 
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
 #ifdef __ARM_NEON
-// we use the built-in 16-bit float type
-typedef __fp16 ggml_fp16_t;
+    // we use the built-in 16-bit float type
+    typedef __fp16 ggml_fp16_t;
 #else
-typedef uint16_t ggml_fp16_t;
+    typedef uint16_t ggml_fp16_t;
 #endif
 
-// convert FP16 <-> FP32
-float       ggml_fp16_to_fp32(ggml_fp16_t x);
-ggml_fp16_t ggml_fp32_to_fp16(float x);
-
-struct ggml_object;
-struct ggml_context;
-
-enum ggml_type {
-    // explicitly numbered values are used in llama.cpp files
-    GGML_TYPE_F32  = 0,
-    GGML_TYPE_F16  = 1,
-    GGML_TYPE_Q4_0 = 2,
-    GGML_TYPE_Q4_1 = 3,
-    GGML_TYPE_Q4_2 = 4,
-    GGML_TYPE_Q4_3 = 5,
-    GGML_TYPE_Q8_0 = 6,
-    GGML_TYPE_I8,
-    GGML_TYPE_I16,
-    GGML_TYPE_I32,
-    GGML_TYPE_COUNT,
-};
-
-// available tensor operations:
-enum ggml_op {
-    GGML_OP_NONE = 0,
-
-    GGML_OP_DUP,
-    GGML_OP_ADD,
-    GGML_OP_SUB,
-    GGML_OP_MUL,
-    GGML_OP_DIV,
-    GGML_OP_SQR,
-    GGML_OP_SQRT,
-    GGML_OP_SUM,
-    GGML_OP_MEAN,
-    GGML_OP_REPEAT,
-    GGML_OP_ABS,
-    GGML_OP_SGN,
-    GGML_OP_NEG,
-    GGML_OP_STEP,
-    GGML_OP_RELU,
-    GGML_OP_GELU,
-    GGML_OP_SILU,
-    GGML_OP_NORM, // normalize
-    GGML_OP_RMS_NORM,
-
-    GGML_OP_MUL_MAT,
-
-    GGML_OP_SCALE,
-    GGML_OP_CPY,
-    GGML_OP_CONT,
-    GGML_OP_RESHAPE,
-    GGML_OP_VIEW,
-    GGML_OP_PERMUTE,
-    GGML_OP_TRANSPOSE,
-    GGML_OP_GET_ROWS,
-    GGML_OP_DIAG_MASK_INF,
-    GGML_OP_SOFT_MAX,
-    GGML_OP_ROPE,
-    GGML_OP_CONV_1D_1S,
-    GGML_OP_CONV_1D_2S,
-
-    GGML_OP_FLASH_ATTN,
-    GGML_OP_FLASH_FF,
-
-    GGML_OP_MAP_UNARY,
-    GGML_OP_MAP_BINARY,
-
-    GGML_OP_COUNT,
-};
-
-
-// ggml object
-struct ggml_object {
-    size_t offs;
-    size_t size;
-
-    struct ggml_object * next;
-
-    char padding[8];
-};
-
-static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
-
-// n-dimensional tensor
-struct ggml_tensor {
-    enum ggml_type type;
-
-    int    n_dims;
-    int64_t ne[GGML_MAX_DIMS]; // number of elements
-    size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
-                               // nb[0] = sizeof(type)
-                               // nb[1] = nb[0]   * ne[0] + padding
-                               // nb[i] = nb[i-1] * ne[i-1]
-
-    // compute data
-    enum ggml_op op;
-
-    bool is_param;
-
-    struct ggml_tensor * grad;
-    struct ggml_tensor * src0;
-    struct ggml_tensor * src1;
-    struct ggml_tensor * opt[GGML_MAX_OPT];
-
-    // thread scheduling
-    int n_tasks;
-
-    // performance
-    int     perf_runs;
-    int64_t perf_cycles;
-    int64_t perf_time_us;
-
-    void * data;
-    char padding[8];
-};
-
-// computation graph
-struct ggml_cgraph {
-    int n_nodes;
-    int n_leafs;
-    int n_threads;
-
-    size_t work_size;
-    struct ggml_tensor * work;
-
-    struct ggml_tensor * nodes[GGML_MAX_NODES];
-    struct ggml_tensor * grads[GGML_MAX_NODES];
-    struct ggml_tensor * leafs[GGML_MAX_NODES];
-
-    // performance
-    int     perf_runs;
-    int64_t perf_cycles;
-    int64_t perf_time_us;
-};
-
-// scratch buffer
-struct ggml_scratch {
-    size_t offs;
-    size_t size;
-    void * data;
-};
+    // convert FP16 <-> FP32
+    GGML_API float       ggml_fp16_to_fp32(ggml_fp16_t x);
+    GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
+
+    struct ggml_object;
+    struct ggml_context;
+
+    enum ggml_type {
+        GGML_TYPE_F32  = 0,
+        GGML_TYPE_F16  = 1,
+        GGML_TYPE_Q4_0 = 2,
+        GGML_TYPE_Q4_1 = 3,
+        GGML_TYPE_Q4_2 = 4,
+        GGML_TYPE_Q4_3 = 5,
+        GGML_TYPE_Q8_0 = 6,
+        GGML_TYPE_I8,
+        GGML_TYPE_I16,
+        GGML_TYPE_I32,
+        GGML_TYPE_COUNT,
+    };
+
+    // available tensor operations:
+    enum ggml_op {
+        GGML_OP_NONE = 0,
+
+        GGML_OP_DUP,
+        GGML_OP_ADD,
+        GGML_OP_SUB,
+        GGML_OP_MUL,
+        GGML_OP_DIV,
+        GGML_OP_SQR,
+        GGML_OP_SQRT,
+        GGML_OP_SUM,
+        GGML_OP_MEAN,
+        GGML_OP_REPEAT,
+        GGML_OP_ABS,
+        GGML_OP_SGN,
+        GGML_OP_NEG,
+        GGML_OP_STEP,
+        GGML_OP_RELU,
+        GGML_OP_GELU,
+        GGML_OP_SILU,
+        GGML_OP_NORM, // normalize
+        GGML_OP_RMS_NORM,
+
+        GGML_OP_MUL_MAT,
+
+        GGML_OP_SCALE,
+        GGML_OP_CPY,
+        GGML_OP_CONT,
+        GGML_OP_RESHAPE,
+        GGML_OP_VIEW,
+        GGML_OP_PERMUTE,
+        GGML_OP_TRANSPOSE,
+        GGML_OP_GET_ROWS,
+        GGML_OP_DIAG_MASK_INF,
+        GGML_OP_SOFT_MAX,
+        GGML_OP_ROPE,
+        GGML_OP_CONV_1D_1S,
+        GGML_OP_CONV_1D_2S,
+
+        GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_FF,
+
+        GGML_OP_MAP_UNARY,
+        GGML_OP_MAP_BINARY,
+
+        GGML_OP_COUNT,
+    };
+
+
+    // ggml object
+    struct ggml_object {
+        size_t offs;
+        size_t size;
+
+        struct ggml_object * next;
+
+        char padding[8];
+    };
+
+    static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
+
+    // n-dimensional tensor
+    struct ggml_tensor {
+        enum ggml_type type;
+
+        int     n_dims;
+        int64_t ne[GGML_MAX_DIMS]; // number of elements
+        size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
+                                   // nb[0] = sizeof(type)
+                                   // nb[1] = nb[0]   * ne[0] + padding
+                                   // nb[i] = nb[i-1] * ne[i-1]
+
+        // compute data
+        enum ggml_op op;
+
+        bool is_param;
+
+        struct ggml_tensor * grad;
+        struct ggml_tensor * src0;
+        struct ggml_tensor * src1;
+        struct ggml_tensor * opt[GGML_MAX_OPT];
+
+        // thread scheduling
+        int n_tasks;
+
+        // performance
+        int     perf_runs;
+        int64_t perf_cycles;
+        int64_t perf_time_us;
+
+        void * data;
+        char padding[8];
+    };
+
+    // computation graph
+    struct ggml_cgraph {
+        int n_nodes;
+        int n_leafs;
+        int n_threads;
+
+        size_t work_size;
+        struct ggml_tensor * work;
+
+        struct ggml_tensor * nodes[GGML_MAX_NODES];
+        struct ggml_tensor * grads[GGML_MAX_NODES];
+        struct ggml_tensor * leafs[GGML_MAX_NODES];
+
+        // performance
+        int     perf_runs;
+        int64_t perf_cycles;
+        int64_t perf_time_us;
+    };
+
+    // scratch buffer
+    struct ggml_scratch {
+        size_t offs;
+        size_t size;
+        void * data;
+    };
 
-struct ggml_init_params {
-    // memory pool
-    size_t mem_size;   // bytes
-    void * mem_buffer; // if NULL, memory will be allocated internally
-    bool   no_alloc;   // don't allocate memory for the tensor data
-};
+    struct ggml_init_params {
+        // memory pool
+        size_t mem_size;   // bytes
+        void * mem_buffer; // if NULL, memory will be allocated internally
+        bool   no_alloc;   // don't allocate memory for the tensor data
+    };
 
-void    ggml_time_init(void); // call this once at the beginning of the program
-int64_t ggml_time_ms(void);
-int64_t ggml_time_us(void);
-int64_t ggml_cycles(void);
-int64_t ggml_cycles_per_ms(void);
+    // misc
 
-void ggml_print_object (const struct ggml_object * obj);
-void ggml_print_objects(const struct ggml_context * ctx);
+    GGML_API void    ggml_time_init(void); // call this once at the beginning of the program
+    GGML_API int64_t ggml_time_ms(void);
+    GGML_API int64_t ggml_time_us(void);
+    GGML_API int64_t ggml_cycles(void);
+    GGML_API int64_t ggml_cycles_per_ms(void);
 
-int64_t ggml_nelements(const struct ggml_tensor * tensor);
-size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
+    GGML_API void    ggml_print_object (const struct ggml_object * obj);
+    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
-int    ggml_blck_size (enum ggml_type type);
-size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
-float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+    GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
 
-const char * ggml_type_name(enum ggml_type type);
+    GGML_API int     ggml_blck_size (enum ggml_type type);
+    GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+    GGML_API float   ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
 
-size_t ggml_element_size(const struct ggml_tensor * tensor);
+    GGML_API const char * ggml_type_name(enum ggml_type type);
 
-bool ggml_is_quantized(enum ggml_type type);
+    GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);
 
-struct ggml_context * ggml_init(struct ggml_init_params params);
-void ggml_free(struct ggml_context * ctx);
+    GGML_API bool    ggml_is_quantized(enum ggml_type type);
 
-size_t ggml_used_mem(const struct ggml_context * ctx);
+    // main
 
-size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
+    GGML_API void    ggml_free(struct ggml_context * ctx);
 
-struct ggml_tensor * ggml_new_tensor(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int    n_dims,
-        const int64_t *ne);
-
-struct ggml_tensor * ggml_new_tensor_1d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0);
-
-struct ggml_tensor * ggml_new_tensor_2d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0,
-        int64_t ne1);
-
-struct ggml_tensor * ggml_new_tensor_3d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0,
-        int64_t ne1,
-        int64_t ne2);
-
-struct ggml_tensor * ggml_new_tensor_4d(
-        struct ggml_context * ctx,
-        enum   ggml_type type,
-        int64_t ne0,
-        int64_t ne1,
-        int64_t ne2,
-        int64_t ne3);
-
-struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
-struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
-
-struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
-struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
-
-struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
-struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
-struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
-
-int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
-void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
-
-float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
-void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
-
- void * ggml_get_data    (const struct ggml_tensor * tensor);
-float * ggml_get_data_f32(const struct ggml_tensor * tensor);
-
-//
-// operations on tensors with backpropagation
-//
-
-struct ggml_tensor * ggml_dup(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_add(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
+    GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);
 
+    GGML_API size_t  ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
 
-struct ggml_tensor * ggml_add_inplace(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
+    GGML_API struct ggml_tensor * ggml_new_tensor(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int    n_dims,
+            const int64_t *ne);
 
-struct ggml_tensor * ggml_sub(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
+    GGML_API struct ggml_tensor * ggml_new_tensor_1d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0);
 
-struct ggml_tensor * ggml_mul(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
+    GGML_API struct ggml_tensor * ggml_new_tensor_2d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1);
 
-struct ggml_tensor * ggml_div(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_sqr(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_sqrt(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// return scalar
-// TODO: compute sum along rows
-struct ggml_tensor * ggml_sum(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// mean along rows
-struct ggml_tensor * ggml_mean(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// if a is the same shape as b, and a is not parameter, return a
-// otherwise, return a new tensor: repeat(a) to fit in b
-struct ggml_tensor * ggml_repeat(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_abs(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_sgn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_neg(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_step(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_relu(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// TODO: double-check this computation is correct
-struct ggml_tensor * ggml_gelu(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_silu(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// normalize along rows
-// TODO: eps is hardcoded to 1e-5 for now
-struct ggml_tensor * ggml_norm(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_rms_norm(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// A: m rows, n columns
-// B: p rows, n columns (i.e. we transpose it internally)
-// result is m columns, p rows
-struct ggml_tensor * ggml_mul_mat(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-//
-// operations on tensors without backpropagation
-//
-
-// in-place, returns view(a)
-struct ggml_tensor * ggml_scale(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// a -> b, return view(b)
-struct ggml_tensor * ggml_cpy(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// make contiguous
-struct ggml_tensor * ggml_cont(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// return view(a), b specifies the new shape
-// TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// return view(a)
-// TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1);
-
-// return view(a)
-// TODO: when we start computing gradient, make a copy instead of view
-struct ggml_tensor * ggml_reshape_3d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1,
-        int64_t               ne2);
-
-// offset in bytes
-struct ggml_tensor * ggml_view_1d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        size_t                offset);
-
-struct ggml_tensor * ggml_view_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1,
-        size_t                nb1, // row stride in bytes
-        size_t                offset);
-
-struct ggml_tensor * ggml_view_3d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int64_t               ne0,
-        int64_t               ne1,
-        int64_t               ne2,
-        size_t                nb1, // row   stride in bytes
-        size_t                nb2, // slice stride in bytes
-        size_t                offset);
-
-struct ggml_tensor * ggml_permute(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   axis0,
-        int                   axis1,
-        int                   axis2,
-        int                   axis3);
-
-// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
-struct ggml_tensor * ggml_transpose(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-struct ggml_tensor * ggml_get_rows(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-// set elements above the diagonal to -INF
-// in-place, returns view(a)
-struct ggml_tensor * ggml_diag_mask_inf(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   n_past);
-
-// in-place, returns view(a)
-struct ggml_tensor * ggml_soft_max(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a);
-
-// rotary position embedding
-// in-place, returns view(a)
-// if mode & 1 == 1, skip n_past elements
-// if mode & 2 == 1, GPT-NeoX style
-// TODO: avoid creating a new tensor every time
-struct ggml_tensor * ggml_rope(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   n_past,
-        int                   n_dims,
-        int                   mode);
-
-// padding = 1
-// TODO: we don't support extra parameters for now
-//       that's why we are hard-coding the stride, padding, and dilation
-//       not great ..
-struct ggml_tensor * ggml_conv_1d_1s(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_conv_1d_2s(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b);
-
-struct ggml_tensor * ggml_flash_attn(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * q,
-        struct ggml_tensor  * k,
-        struct ggml_tensor  * v,
-        bool                  masked);
-
-struct ggml_tensor * ggml_flash_ff(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b0,
-        struct ggml_tensor  * b1,
-        struct ggml_tensor  * c0,
-        struct ggml_tensor  * c1);
-
-// Mapping operations
-typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
-typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
-
-struct ggml_tensor * ggml_map_unary_f32(
-        struct ggml_context        * ctx,
-        struct ggml_tensor         * a,
-        const  ggml_unary_op_f32_t fun);
-
-struct ggml_tensor * ggml_map_binary_f32(
-        struct ggml_context         * ctx,
-        struct ggml_tensor          * a,
-        struct ggml_tensor          * b,
-        const  ggml_binary_op_f32_t fun);
-
-//
-// automatic differentiation
-//
-
-void ggml_set_param(
-        struct ggml_context * ctx,
-        struct ggml_tensor * tensor);
-
-void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-
-struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
-struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
-
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
-void ggml_graph_reset  (struct ggml_cgraph * cgraph);
-
-// print info and performance information for the graph
-void ggml_graph_print(const struct ggml_cgraph * cgraph);
-
-// dump the graph into a file using the dot format
-void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
-
-//
-// optimization
-//
-
-// optimization methods
-enum ggml_opt_type {
-    GGML_OPT_ADAM,
-    GGML_OPT_LBFGS,
-};
-
-// linesearch methods
-enum ggml_linesearch {
-    GGML_LINESEARCH_DEFAULT = 1,
-
-    GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
-    GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
-    GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
-};
-
-// optimization return values
-enum ggml_opt_result {
-    GGML_OPT_OK = 0,
-    GGML_OPT_DID_NOT_CONVERGE,
-    GGML_OPT_NO_CONTEXT,
-    GGML_OPT_INVALID_WOLFE,
-    GGML_OPT_FAIL,
+    GGML_API struct ggml_tensor * ggml_new_tensor_3d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2);
 
-    GGML_LINESEARCH_FAIL = -128,
-    GGML_LINESEARCH_MINIMUM_STEP,
-    GGML_LINESEARCH_MAXIMUM_STEP,
-    GGML_LINESEARCH_MAXIMUM_ITERATIONS,
-    GGML_LINESEARCH_INVALID_PARAMETERS,
-};
+    GGML_API struct ggml_tensor * ggml_new_tensor_4d(
+            struct ggml_context * ctx,
+            enum   ggml_type type,
+            int64_t ne0,
+            int64_t ne1,
+            int64_t ne2,
+            int64_t ne3);
 
-// optimization parameters
-//
-//   see ggml.c (ggml_opt_default_params) for default values
-//
-struct ggml_opt_params {
-    enum ggml_opt_type type;
+    GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+    GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+
+    GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
+    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+
+    GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+    GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+
+    GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_API void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+
+    GGML_API float   ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+    GGML_API void    ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
-    int n_threads;
+    GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
+    GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
 
-    // delta-based convergence test
     //
-    //   if past == 0 - disabled
-    //   if past > 0:
-    //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+    // operations on tensors with backpropagation
     //
-    int past;
-    float delta;
 
-    // maximum number of iterations without improvement
+    GGML_API struct ggml_tensor * ggml_dup(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_add(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_add_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sub(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_mul(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_div(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_sqr(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sqrt(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // return scalar
+    // TODO: compute sum along rows
+    GGML_API struct ggml_tensor * ggml_sum(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // mean along rows
+    GGML_API struct ggml_tensor * ggml_mean(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // if a is the same shape as b, and a is not parameter, return a
+    // otherwise, return a new tensor: repeat(a) to fit in b
+    GGML_API struct ggml_tensor * ggml_repeat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_abs(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sgn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_neg(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_step(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_relu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // TODO: double-check this computation is correct
+    GGML_API struct ggml_tensor * ggml_gelu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_silu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // normalize along rows
+    // TODO: eps is hardcoded to 1e-5 for now
+    GGML_API struct ggml_tensor * ggml_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_rms_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // A: m rows, n columns
+    // B: p rows, n columns (i.e. we transpose it internally)
+    // result is m columns, p rows
+    GGML_API struct ggml_tensor * ggml_mul_mat(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     //
-    //   if 0 - disabled
-    //   if > 0:
-    //     assume convergence if no cost improvement in this number of iterations
+    // operations on tensors without backpropagation
     //
-    int max_no_improvement;
 
-    bool print_forward_graph;
-    bool print_backward_graph;
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_scale(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // a -> b, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // make contiguous
+    GGML_API struct ggml_tensor * ggml_cont(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // return view(a), b specifies the new shape
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1);
+
+    // return view(a)
+    // TODO: when we start computing gradient, make a copy instead of view
+    GGML_API struct ggml_tensor * ggml_reshape_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2);
+
+    // offset in bytes
+    GGML_API struct ggml_tensor * ggml_view_1d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_2d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            size_t                nb1, // row stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_view_3d(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int64_t               ne0,
+            int64_t               ne1,
+            int64_t               ne2,
+            size_t                nb1, // row   stride in bytes
+            size_t                nb2, // slice stride in bytes
+            size_t                offset);
+
+    GGML_API struct ggml_tensor * ggml_permute(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   axis0,
+            int                   axis1,
+            int                   axis2,
+            int                   axis3);
+
+    // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
+    GGML_API struct ggml_tensor * ggml_transpose(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_get_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    // set elements above the diagonal to -INF
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_diag_mask_inf(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_soft_max(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    // rotary position embedding
+    // in-place, returns view(a)
+    // if mode & 1 == 1, skip n_past elements
+    // if mode & 2 == 1, GPT-NeoX style
+    // TODO: avoid creating a new tensor every time
+    GGML_API struct ggml_tensor * ggml_rope(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode);
+
+    // padding = 1
+    // TODO: we don't support extra parameters for now
+    //       that's why we are hard-coding the stride, padding, and dilation
+    //       not great ..
+    GGML_API struct ggml_tensor * ggml_conv_1d_1s(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_conv_1d_2s(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
+    GGML_API struct ggml_tensor * ggml_flash_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            bool                  masked);
+
+    GGML_API struct ggml_tensor * ggml_flash_ff(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b0,
+            struct ggml_tensor  * b1,
+            struct ggml_tensor  * c0,
+            struct ggml_tensor  * c1);
+
+    // Mapping operations
+    GGML_API typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
+    GGML_API typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+
+    GGML_API struct ggml_tensor * ggml_map_unary_f32(
+            struct ggml_context        * ctx,
+            struct ggml_tensor         * a,
+            const  ggml_unary_op_f32_t fun);
+
+    GGML_API struct ggml_tensor * ggml_map_binary_f32(
+            struct ggml_context         * ctx,
+            struct ggml_tensor          * a,
+            struct ggml_tensor          * b,
+            const  ggml_binary_op_f32_t fun);
 
-    // ADAM parameters
-    struct {
-        int n_iter;
+    //
+    // automatic differentiation
+    //
 
-        float alpha; // learning rate
-        float beta1;
-        float beta2;
-        float eps;   // epsilon for numerical stability
-        float eps_f; // epsilon for convergence test
-        float eps_g; // epsilon for convergence test
-    } adam;
+    GGML_API void ggml_set_param(
+            struct ggml_context * ctx,
+            struct ggml_tensor * tensor);
 
-    // LBFGS parameters
-    struct {
-        int m; // number of corrections to approximate the inv. Hessian
-        int n_iter;
-        int max_linesearch;
+    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
 
-        float eps;      // convergence tolerance
-        float ftol;     // line search tolerance
-        float wolfe;
-        float min_step;
-        float max_step;
+    GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
+    GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
 
-        enum ggml_linesearch linesearch;
-    } lbfgs;
-};
+    GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
+    GGML_API void ggml_graph_reset  (struct ggml_cgraph * cgraph);
 
-struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+    // print info and performance information for the graph
+    GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
-// optimize the function defined by the tensor f
-enum ggml_opt_result ggml_opt(
-        struct ggml_context * ctx,
-        struct ggml_opt_params params,
-        struct ggml_tensor * f);
+    // dump the graph into a file using the dot format
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
 
-//
-// quantization
-//
+    //
+    // optimization
+    //
 
-size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
-size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
+    // optimization methods
+    enum ggml_opt_type {
+        GGML_OPT_ADAM,
+        GGML_OPT_LBFGS,
+    };
+
+    // linesearch methods
+    enum ggml_linesearch {
+        GGML_LINESEARCH_DEFAULT = 1,
+
+        GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
+        GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
+        GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
+    };
+
+    // optimization return values
+    enum ggml_opt_result {
+        GGML_OPT_OK = 0,
+        GGML_OPT_DID_NOT_CONVERGE,
+        GGML_OPT_NO_CONTEXT,
+        GGML_OPT_INVALID_WOLFE,
+        GGML_OPT_FAIL,
+
+        GGML_LINESEARCH_FAIL = -128,
+        GGML_LINESEARCH_MINIMUM_STEP,
+        GGML_LINESEARCH_MAXIMUM_STEP,
+        GGML_LINESEARCH_MAXIMUM_ITERATIONS,
+        GGML_LINESEARCH_INVALID_PARAMETERS,
+    };
+
+    // optimization parameters
+    //
+    //   see ggml.c (ggml_opt_default_params) for default values
+    //
+    struct ggml_opt_params {
+        enum ggml_opt_type type;
+
+        int n_threads;
+
+        // delta-based convergence test
+        //
+        //   if past == 0 - disabled
+        //   if past > 0:
+        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
+        //
+        int past;
+        float delta;
+
+        // maximum number of iterations without improvement
+        //
+        //   if 0 - disabled
+        //   if > 0:
+        //     assume convergence if no cost improvement in this number of iterations
+        //
+        int max_no_improvement;
+
+        bool print_forward_graph;
+        bool print_backward_graph;
+
+        // ADAM parameters
+        struct {
+            int n_iter;
+
+            float alpha; // learning rate
+            float beta1;
+            float beta2;
+            float eps;   // epsilon for numerical stability
+            float eps_f; // epsilon for convergence test
+            float eps_g; // epsilon for convergence test
+        } adam;
+
+        // LBFGS parameters
+        struct {
+            int m; // number of corrections to approximate the inv. Hessian
+            int n_iter;
+            int max_linesearch;
+
+            float eps;      // convergence tolerance
+            float ftol;     // line search tolerance
+            float wolfe;
+            float min_step;
+            float max_step;
+
+            enum ggml_linesearch linesearch;
+        } lbfgs;
+    };
+
+    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
+
+    // optimize the function defined by the tensor f
+    GGML_API enum ggml_opt_result ggml_opt(
+            struct ggml_context * ctx,
+            struct ggml_opt_params params,
+            struct ggml_tensor * f);
 
-size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
+    //
+    // quantization
+    //
 
-//
-// system info
-//
+    GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
+    GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
 
-int ggml_cpu_has_avx(void);
-int ggml_cpu_has_avx2(void);
-int ggml_cpu_has_avx512(void);
-int ggml_cpu_has_avx512_vbmi(void);
-int ggml_cpu_has_avx512_vnni(void);
-int ggml_cpu_has_fma(void);
-int ggml_cpu_has_neon(void);
-int ggml_cpu_has_arm_fma(void);
-int ggml_cpu_has_f16c(void);
-int ggml_cpu_has_fp16_va(void);
-int ggml_cpu_has_wasm_simd(void);
-int ggml_cpu_has_blas(void);
-int ggml_cpu_has_cublas(void);
-int ggml_cpu_has_sse3(void);
-int ggml_cpu_has_vsx(void);
+    GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
 
+    //
+    // system info
+    //
 
-//
-// Internal types and functions exposed for tests and benchmarks
-//
+    GGML_API int ggml_cpu_has_avx        (void);
+    GGML_API int ggml_cpu_has_avx2       (void);
+    GGML_API int ggml_cpu_has_avx512     (void);
+    GGML_API int ggml_cpu_has_avx512_vbmi(void);
+    GGML_API int ggml_cpu_has_avx512_vnni(void);
+    GGML_API int ggml_cpu_has_fma        (void);
+    GGML_API int ggml_cpu_has_neon       (void);
+    GGML_API int ggml_cpu_has_arm_fma    (void);
+    GGML_API int ggml_cpu_has_f16c       (void);
+    GGML_API int ggml_cpu_has_fp16_va    (void);
+    GGML_API int ggml_cpu_has_wasm_simd  (void);
+    GGML_API int ggml_cpu_has_blas       (void);
+    GGML_API int ggml_cpu_has_cublas     (void);
+    GGML_API int ggml_cpu_has_sse3       (void);
+    GGML_API int ggml_cpu_has_vsx        (void);
+
+
+    //
+    // Internal types and functions exposed for tests and benchmarks
+    //
 
 #ifdef  __cplusplus
-// restrict not standard in C++
+    // restrict not standard in C++
 #define GGML_RESTRICT
 #else
 #define GGML_RESTRICT restrict
 #endif
-typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
-typedef void (*quantize_row_q_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
-typedef void (*vec_dot_q_t)(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
-
-typedef struct {
-    dequantize_row_q_t dequantize_row_q;
-    quantize_row_q_t   quantize_row_q;
-    quantize_row_q_t   quantize_row_q_reference;
-    quantize_row_q_t   quantize_row_q_dot;
-    vec_dot_q_t        vec_dot_q;
-} quantize_fns_t;
-
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
+    typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
+    typedef void (*quantize_row_q_t)  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
+    typedef void (*vec_dot_q_t)       (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
+
+    typedef struct {
+        dequantize_row_q_t dequantize_row_q;
+        quantize_row_q_t   quantize_row_q;
+        quantize_row_q_t   quantize_row_q_reference;
+        quantize_row_q_t   quantize_row_q_dot;
+        vec_dot_q_t        vec_dot_q;
+    } quantize_fns_t;
+
+    quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
 
 #ifdef  __cplusplus
 }