]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : sync latest llama.cpp (view_src + alloc improvements) (#1247)
authorGeorgi Gerganov <redacted>
Tue, 5 Sep 2023 17:57:27 +0000 (20:57 +0300)
committerGitHub <redacted>
Tue, 5 Sep 2023 17:57:27 +0000 (20:57 +0300)
* ggml : sync latest llama.cpp (view_src + alloc improvements)

* ggml : fix build

ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml-opencl.cpp
ggml.c
ggml.h

index 391e63181c0fc688e686eef2434a75ec7b05674a..d2dbf824ef2da9145161d1f43dbfedd7e107ce6e 100644 (file)
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
 
+#ifndef __has_builtin
+    #define __has_builtin(x) 0
+#endif
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
     const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
     const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
     const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
     return reinterpret_cast<const int&>(c);
+#else
+    int8x4_t c;
+    int16_t tmp;
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+        tmp = va[i] - vb[i];
+        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+        c[i] = tmp;
+    }
+    return reinterpret_cast<int&>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
 }
 
 static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
@@ -447,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const float eps = 1e-5f;
 
-    float mean = 0.0f;
-    float var = 0.0f;
+    float2 mean_var = make_float2(0.f, 0.f);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
-        mean += xi;
-        var += xi * xi;
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
-        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
     }
 
-    mean /= ncols;
-    var = var / ncols - mean * mean;
-    const float inv_var = rsqrtf(var + eps);
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
     }
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
@@ -4186,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
@@ -5721,7 +5781,6 @@ inline void ggml_cuda_op_alibi(
     (void) src1;
     (void) src0_ddq_i;
     (void) src1_ddf_i;
-    (void) i02;
     (void) i1;
 }
 
index ad2ee8cf5fef098169d765b7fa92b12f57c9f70c..d0d23442eab5d83b74768d0b74bef674861d5e06 100644 (file)
@@ -11,6 +11,7 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// TODO: temporary - reuse llama.cpp logging
 #ifdef GGML_METAL_NDEBUG
 #define metal_printf(...)
 #else
@@ -75,6 +76,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -113,12 +115,26 @@ static NSString * const msl_library_source = @"see metal.metal";
 @end
 
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
-    fprintf(stderr, "%s: allocating\n", __func__);
+    metal_printf("%s: allocating\n", __func__);
+
+    // Show all the Metal device instances in the system
+    NSArray * devices = MTLCopyAllDevices();
+    id <MTLDevice> device;
+    NSString * s;
+    for (device in devices) {
+        s = [device name];
+        metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
+    }
 
-    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+    // Pick and show default Metal device
+    device = MTLCreateSystemDefaultDevice();
+    s = [device name];
+    metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]);
 
+    // Configure context
+    struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+    ctx->device = device;
     ctx->n_cb   = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
-    ctx->device = MTLCreateSystemDefaultDevice();
     ctx->queue  = [ctx->device newCommandQueue];
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
@@ -132,7 +148,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -146,11 +162,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
         NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);
+        metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
 
@@ -162,7 +178,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 #endif
         if (error) {
-            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
@@ -174,11 +190,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #define GGML_METAL_ADD_KERNEL(name) \
         ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
         ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        fprintf(stderr, "%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
+        metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
                 (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
                 (int) ctx->pipeline_##name.threadExecutionWidth); \
         if (error) { \
-            fprintf(stderr, "%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+            metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             return NULL; \
         }
 
@@ -204,6 +220,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -230,19 +247,19 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    fprintf(stderr, "%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
-    fprintf(stderr, "%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    metal_printf("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
+    metal_printf("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     if (ctx->device.maxTransferRate != 0) {
-        fprintf(stderr, "%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
+        metal_printf("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
-        fprintf(stderr, "%s: maxTransferRate               = built-in GPU\n", __func__);
+        metal_printf("%s: maxTransferRate               = built-in GPU\n", __func__);
     }
 
     return ctx;
 }
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
-    fprintf(stderr, "%s: deallocating\n", __func__);
+    metal_printf("%s: deallocating\n", __func__);
 #define GGML_METAL_DEL_KERNEL(name) \
     [ctx->function_##name release]; \
     [ctx->pipeline_##name release];
@@ -269,6 +286,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -311,7 +329,7 @@ void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
     const int result = posix_memalign((void **) &data, getpagesize(), n);
     if (result != 0) {
-        fprintf(stderr, "%s: error: posix_memalign failed\n", __func__);
+        metal_printf("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
 
@@ -339,7 +357,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx) {
 // Metal buffer based on the host memory pointer
 //
 static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
-    //fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
+    //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
 
     const int64_t tsize = ggml_nbytes(t);
 
@@ -350,13 +368,13 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
-            //fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
+            //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
 
             return ctx->buffers[i].metal;
         }
     }
 
-    fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+    metal_printf("%s: error: buffer is nil\n", __func__);
 
     return nil;
 }
@@ -368,7 +386,7 @@ bool ggml_metal_add_buffer(
                          size_t   size,
                          size_t   max_size) {
     if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
-        fprintf(stderr, "%s: too many buffers\n", __func__);
+        metal_printf("%s: too many buffers\n", __func__);
         return false;
     }
 
@@ -378,7 +396,7 @@ bool ggml_metal_add_buffer(
             const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
 
             if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
-                fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
+                metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
                 return false;
             }
         }
@@ -399,11 +417,11 @@ bool ggml_metal_add_buffer(
             ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
             if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
+                metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
                 return false;
             }
 
-            fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
+            metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
 
             ++ctx->n_buffers;
         } else {
@@ -423,27 +441,27 @@ bool ggml_metal_add_buffer(
                 ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
 
                 if (ctx->buffers[ctx->n_buffers].metal == nil) {
-                    fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
+                    metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
                     return false;
                 }
 
-                fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
+                metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
                 if (i + size_step < size) {
-                    fprintf(stderr, "\n");
+                    metal_printf("\n");
                 }
 
                 ++ctx->n_buffers;
             }
         }
 
-        fprintf(stderr, ", (%8.2f / %8.2f)",
+        metal_printf(", (%8.2f / %8.2f)",
                 ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
                 ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
 
         if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
-            fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
+            metal_printf(", warning: current allocated size is greater than the recommended max working set size\n");
         } else {
-            fprintf(stderr, "\n");
+            metal_printf("\n");
         }
     }
 
@@ -453,8 +471,6 @@ bool ggml_metal_add_buffer(
 void ggml_metal_set_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -464,8 +480,6 @@ void ggml_metal_set_tensor(
 void ggml_metal_get_tensor(
         struct ggml_metal_context * ctx,
         struct ggml_tensor * t) {
-    metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
-
     size_t offs;
     id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
 
@@ -560,15 +574,13 @@ void ggml_metal_graph_find_concurrency(
     }
 
     if (ctx->concur_list_len > GGML_MAX_CONCUR) {
-        fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+        metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__);
     }
 }
 
 void ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
-    metal_printf("%s: evaluating graph\n", __func__);
-
     @autoreleasepool {
 
     // if there is ctx->concur_list, dispatch concurrently
@@ -616,7 +628,7 @@ void ggml_metal_graph_compute(
                     continue;
                 }
 
-                metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
                 struct ggml_tensor * src0 = gf->nodes[i]->src[0];
                 struct ggml_tensor * src1 = gf->nodes[i]->src[1];
@@ -685,6 +697,12 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_ADD:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
@@ -694,14 +712,20 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_MUL:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            // utilize float4
+                            GGML_ASSERT(ne00 % 4 == 0);
+                            const int64_t nb = ne00/4;
+
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
@@ -711,9 +735,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&nb     length:sizeof(nb) atIndex:3];
 
-                            const int64_t n = ggml_nelements(dst);
+                            const int64_t n = ggml_nelements(dst)/4;
 
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
@@ -764,7 +788,7 @@ void ggml_metal_graph_compute(
                                 } break;
                             default:
                                 {
-                                    fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                                    metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                                     GGML_ASSERT(false);
                                 }
                         } break;
@@ -845,9 +869,13 @@ void ggml_metal_graph_compute(
                                 switch (src0t) {
                                     case GGML_TYPE_F16:
                                         {
-                                            nth0 = 64;
+                                            nth0 = 32;
                                             nth1 = 1;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                            if (ne11 * ne12 < 4) {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                            } else {
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                            }
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
@@ -899,8 +927,8 @@ void ggml_metal_graph_compute(
                                             GGML_ASSERT(ne02 == 1);
                                             GGML_ASSERT(ne12 == 1);
 
-                                            nth0 = 2;
-                                            nth1 = 32;
+                                            nth0 = 4; //1;
+                                            nth1 = 8; //32;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
@@ -923,7 +951,7 @@ void ggml_metal_graph_compute(
                                         } break;
                                     default:
                                         {
-                                            fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+                                            metal_printf("Asserting on type %d\n",(int)src0t);
                                             GGML_ASSERT(false && "not implemented");
                                         }
                                 };
@@ -948,9 +976,12 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                                    src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+                                    src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
+                                else if (src0t == GGML_TYPE_Q4_K) {
+                                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                }
                                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -964,8 +995,8 @@ void ggml_metal_graph_compute(
                                 else if (src0t == GGML_TYPE_Q6_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 } else {
-                                    [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
-                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                                    int64_t ny = (ne11 + 3)/4;
+                                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                             }
                         } break;
@@ -1161,7 +1192,7 @@ void ggml_metal_graph_compute(
                         } break;
                     default:
                         {
-                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                             GGML_ASSERT(false);
                         }
                 }
@@ -1186,7 +1217,7 @@ void ggml_metal_graph_compute(
 
         MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
-            fprintf(stderr, "%s: command buffer %d failed with status %lu\n", __func__, i, status);
+            metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             GGML_ASSERT(false);
         }
     }
index 82e1a0c7aca06e124898dc1c8caa7064bafef261..119fcbeb623c11ca25a1c59c08564d7c3da32fba 100644 (file)
@@ -25,9 +25,9 @@ typedef struct {
 } block_q8_0;
 
 kernel void kernel_add(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig];
 }
@@ -35,18 +35,18 @@ kernel void kernel_add(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant   int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] + src1[tpig % ne00];
+    dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_mul(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * src1[tpig];
 }
@@ -54,12 +54,12 @@ kernel void kernel_mul(
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_mul_row(
-        device const float * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
+        device const float4 * src0,
+        device const float4 * src1,
+        device       float4 * dst,
+        constant    int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * src1[tpig % ne00];
+    dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_scale(
@@ -133,19 +133,24 @@ kernel void kernel_soft_max(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
-    }
+    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
+    //               the loop, and when that is done, buf[0] has the correct (synchronized) value
+    //if (tpitg[0] == 0) {
+    //    buf[0] = buf[0];
+    //}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
 
     const float max = buf[0];
 
     // parallel sum
     buf[tpitg[0]] = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] += exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp(psrc0[i00] - max);
+        buf[tpitg[0]] += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // whish to compute it twice.
+        pdst[i00] = exp_psrc0;
     }
 
     // reduce
@@ -157,17 +162,18 @@ kernel void kernel_soft_max(
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
-    // broadcast
-    if (tpitg[0] == 0) {
-        buf[0] = buf[0];
-    }
+    // broadcast - not needed, see above
+    //// broadcast
+    //if (tpitg[0] == 0) {
+    //    buf[0] = buf[0];
+    //}
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
 
     const float sum = buf[0];
 
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        pdst[i00] = exp(psrc0[i00] - max) / sum;
+        pdst[i00] /= sum;
     }
 }
 
@@ -214,25 +220,27 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //// broadcast
+    //if (tpitg == 0) {
+    //    sum[0] /= ne00;
+    //}
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
     const float mean  = sum[0];
 
-    // recenter
+    // recenter and VARIANCE
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
+
+    //// VARIANCE
+    //// parallel sum
+    //sum[tpitg] = 0.0f;
+    //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    //    sum[tpitg] += y[i00] * y[i00];
+    //}
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -241,11 +249,11 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //// broadcast
+    //if (tpitg == 0) {
+    //    sum[0] /= ne00;
+    //}
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
     const float variance = sum[0];
 
     const float scale = 1.0f/sqrt(variance + eps);
@@ -435,6 +443,8 @@ kernel void kernel_mul_mat_q4_1_f32(
      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+#define NB_Q8_0 8
+
 kernel void kernel_mul_mat_q8_0_f32(
         device const  void * src0,
         device const float * src1,
@@ -463,30 +473,30 @@ kernel void kernel_mul_mat_q8_0_f32(
     device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
     device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
 
-    float yl[16];
+    float yl[NB_Q8_0];
     float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = tiisg%2;
+    const int ix = tiisg/4;
+    const int il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + 16*il;
+    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
 
-    // each thread in a SIMD group deals with half a block.
-    for (int ib = ix; ib < nb; ib += nw/2) {
-        for (int i = 0; i < 16; ++i) {
+    // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+    for (int ib = ix; ib < nb; ib += nw/4) {
+        for (int i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
             float sumq = 0.f;
-            for (int iq = 0; iq < 16; ++iq) {
+            for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*x[ib+row*nb].d;
         }
 
-        yb += QK8_0 * 16;
+        yb += NB_Q8_0 * nw;
     }
 
     for (int row = 0; row < nr; ++row) {
@@ -497,7 +507,7 @@ kernel void kernel_mul_mat_q8_0_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -515,11 +525,8 @@ kernel void kernel_mul_mat_f16_f32(
         constant  uint64_t & nb12,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
-        threadgroup float  * sum [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3  tpig[[thread_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3  tptg[[threads_per_threadgroup]]) {
+        uint tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -528,24 +535,102 @@ kernel void kernel_mul_mat_f16_f32(
     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-    sum[tpitg.x] = 0.0f;
-
-    for (int i = tpitg.x; i < ne00; i += tptg.x) {
-        sum[tpitg.x] += (float) x[i] * (float) y[i];
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
     }
 
-    // accumulate the sum from all threads in the threadgroup
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = tptg.x/2; i > 0; i /= 2) {
-        if (tpitg.x < i) {
-            sum[tpitg.x] += sum[tpitg.x + i];
+}
+
+#define N_F16_F32 4
+
+kernel void kernel_mul_mat_f16_f32(
+        device const  char * src0,
+        device const  char * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]]) {
+
+    const int64_t r0 = tgpig.x;
+    const int64_t rb = tgpig.y*N_F16_F32;
+    const int64_t im = tgpig.z;
+
+    device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
         }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
 
-    if (tpitg.x == 0) {
-        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
     }
+
 }
 
 kernel void kernel_alibi_f32(
@@ -1244,7 +1329,8 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int r2 = tgpig.z;
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = r0 * N_DST;
     const int ib_row = first_row * nb;
     const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
index eb214a836489b04a4e3d6b9c6444ab30570edb89..777048d01115779f666457d70aeed3a833bc44b8 100644 (file)
@@ -1334,7 +1334,7 @@ void ggml_cl_free_data(const struct ggml_tensor* tensor) {
         return;
     }
 
-    cl_mem mem = (cl_mem)tensor->data;
+    cl_mem mem = (cl_mem)tensor->extra;
     clReleaseMemObject(mem);
 }
 
@@ -1393,7 +1393,7 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
     size_t d_size;
 
     cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
-    cl_mem d_Y = (cl_mem) src1->data; // src1 is already on device, broadcasted.
+    cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
     cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
 
 
@@ -1491,9 +1491,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     size_t d_size;
     cl_mem d_X;
     if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
-        d_X = (cl_mem) src0->data;
+        d_X = (cl_mem) src0->extra;
     } else {
-        d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
+        d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
     }
     cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
     cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
@@ -1567,7 +1567,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     size_t d_size;
     cl_mem d_X;
     if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
-        d_X = (cl_mem) src0->data;
+        d_X = (cl_mem) src0->extra;
     } else {
         d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
     }
@@ -1697,7 +1697,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
                 events.emplace_back();
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
             } else if (src0->backend == GGML_BACKEND_GPU) {
-                d_Q = (cl_mem) src0->data;
+                d_Q = (cl_mem) src0->extra;
             } else {
                 GGML_ASSERT(false);
             }
@@ -1860,6 +1860,6 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
 
     CL_CHECK(clFinish(queue));
 
-    tensor->data = dst;
+    tensor->extra = dst;
     GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
 }
diff --git a/ggml.c b/ggml.c
index ecbc724b508c23e1b67379cf2b779befd6247232..63a22223a4a003699527f813a617d3328bef9c17 100644 (file)
--- a/ggml.c
+++ b/ggml.c
 // disable "possible loss of data" to avoid hundreds of casts
 // we should just be careful :)
 #pragma warning(disable: 4244 4267)
+
+// disable POSIX deprecation warnigns
+// these functions are never going away, anyway
+#pragma warning(disable: 4996)
 #endif
 
 #if defined(_WIN32)
@@ -123,6 +127,8 @@ typedef void * thread_ret_t;
 #define GGML_GELU_FP16
 #define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
+// #define GGML_CROSS_ENTROPY_EXP_FP16
+// #define GGML_FLASH_ATTN_EXP_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
 #define GGML_VEC_DOT_UNROLL  2
@@ -186,8 +192,8 @@ typedef void * thread_ret_t;
 //
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
-#define GGML_ALIGNED_MALLOC(size)  _aligned_malloc(size, GGML_MEM_ALIGN)
-#define GGML_ALIGNED_FREE(ptr)     _aligned_free(ptr)
+#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
+#define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
 #else
 inline static void * ggml_aligned_malloc(size_t size) {
     void * aligned_memory = NULL;
@@ -212,8 +218,8 @@ inline static void * ggml_aligned_malloc(size_t size) {
     }
     return aligned_memory;
 }
-#define GGML_ALIGNED_MALLOC(size)  ggml_aligned_malloc(size)
-#define GGML_ALIGNED_FREE(ptr)     free(ptr)
+#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
+#define GGML_ALIGNED_FREE(ptr)    free(ptr)
 #endif
 
 #define UNUSED GGML_UNUSED
@@ -301,6 +307,10 @@ typedef double ggml_float;
 #endif
 #endif
 
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
 #ifdef __F16C__
 
 #ifdef _MSC_VER
@@ -665,7 +675,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
 }
 
 static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
-#ifdef __AVXVNNI__
+#if __AVXVNNI__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
     return _mm256_cvtepi32_ps(summed_pairs);
@@ -678,7 +688,7 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
 
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-#ifdef __AVXVNNIINT8__
+#if __AVXVNNIINT8__
     const __m256i zero = _mm256_setzero_si256();
     const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
     return _mm256_cvtepi32_ps(summed_pairs);
@@ -694,7 +704,7 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
-#ifdef __AVX512F__
+#if __AVX512F__
     const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4);   // 0000_0000_abcd_0000
     bytes = _mm256_or_si256(bytes, bytes_srli_4);               // 0000_abcd_abcd_efgh
     return _mm256_cvtepi16_epi8(bytes);                         // abcd_efgh
@@ -813,46 +823,6 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 
 #if !defined(__aarch64__)
 
-inline static uint16_t vaddvq_u8(uint8x16_t v) {
-    return
-        (uint16_t)vgetq_lane_u8(v, 0)  + (uint16_t)vgetq_lane_u8(v, 1)  +
-        (uint16_t)vgetq_lane_u8(v, 2)  + (uint16_t)vgetq_lane_u8(v, 3)  +
-        (uint16_t)vgetq_lane_u8(v, 4)  + (uint16_t)vgetq_lane_u8(v, 5)  +
-        (uint16_t)vgetq_lane_u8(v, 6)  + (uint16_t)vgetq_lane_u8(v, 7)  +
-        (uint16_t)vgetq_lane_u8(v, 8)  + (uint16_t)vgetq_lane_u8(v, 9)  +
-        (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
-        (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
-        (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
-}
-
-inline static int16_t vaddvq_s8(int8x16_t v) {
-    return
-        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
-        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
-        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
-        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
-        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
-        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
-        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
-        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
-}
-
-inline static int32_t vaddvq_s16(int16x8_t v) {
-    return
-        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
-        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
-        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
-        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
-}
-
-inline static uint32_t vaddvq_u16(uint16x8_t v) {
-    return
-        (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
-        (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
-        (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
-        (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
-}
-
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
@@ -861,12 +831,6 @@ inline static float vaddvq_f32(float32x4_t v) {
     return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
 }
 
-inline static float vminvq_f32(float32x4_t v) {
-    return
-        MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
-            MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
-}
-
 inline static float vmaxvq_f32(float32x4_t v) {
     return
         MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
@@ -1294,7 +1258,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
 #endif
     }
 #else
-    (void)nb;
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
@@ -1513,7 +1476,6 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
 #endif
     }
 #else
-    (void)nb;
     // scalar
     quantize_row_q8_1_reference(x, y, k);
 #endif
@@ -2679,6 +2641,41 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+
+        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+
+        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl);
+        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl);
+
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
+        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d);
+    }
+
+    *s = sumf;
 #else
     // scalar
     float sumf = 0.0;
@@ -2805,6 +2802,38 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+
+        vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+
+        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
+        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
 #else
     // scalar
     float sumf = 0.0;
@@ -3039,6 +3068,76 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    uint32_t qh;
+
+    // These temp values are for masking and shift operations
+    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+    uint32_t temp_2[16] = {0x1,   0x2,   0x4,   0x8,   0x10,   0x20,   0x40,   0x80,
+                         0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // temporary registers
+        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl);
+        vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl);
+        vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl);
+        vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl);
+
+        // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
+        vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl);
+        vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl);
+        vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
+
+        // ((qh & (1u << (j + 16))) >> (j + 12));
+        vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl);
+        vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl);
+
+        // narrowing
+        vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl);
+        vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
+
+        vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl);
+        vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
+
+        // load
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+
+        vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+
+        vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
+        vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
+
+        vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+
+        vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl);
+        vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl);
+
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
+        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi;
+    }
+
+    *s = sumf;
 #else
     // scalar
     float sumf = 0.0;
@@ -3295,6 +3394,72 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc) + summs;
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+
+    uint32_t qh;
+
+    // These temp values are for shift operations
+    uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+    size_t vl = __riscv_vsetvl_e8m1(qk/2);
+
+    for (int i = 0; i < nb; i++) {
+        memcpy(&qh, x[i].qh, sizeof(uint32_t));
+
+        // temporary registers
+        vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl);
+        vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl);
+
+        // load qh
+        vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl);
+
+        // ((qh >> (j +  0)) << 4) & 0x10;
+        vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl);
+        vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl);
+        vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl);
+
+        // ((qh >> (j + 12))     ) & 0x10;
+        vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl);
+        vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl);
+
+        // narrowing
+        vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl);
+        vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl);
+
+        vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl);
+        vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl);
+
+        // load
+        vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl);
+
+        vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
+        vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl);
+
+        vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl);
+        vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl);
+
+        vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl);
+        vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl);
+
+        vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a);
+        vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l);
+
+        vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl);
+        vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl);
+
+        vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl);
+
+        vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl);
+        vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(vs1);
+        sumi += __riscv_vmv_x_s_i32m1_i32(vs2);
+
+        sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s;
+    }
+
+    *s = sumf;
 #else
     // scalar
     float sumf = 0.0;
@@ -3406,6 +3571,26 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__riscv_v_intrinsic)
+    float sumf = 0.0;
+    size_t vl = __riscv_vsetvl_e8m1(qk);
+
+    for (int i = 0; i < nb; i++) {
+        // load elements
+        vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl);
+        vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl);
+
+        vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl);
+
+        vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
+        vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);
+
+        int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum);
+
+        sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d));
+    }
+
+    *s = sumf;
 #else
     // scalar
     float sumf = 0.0;
@@ -4106,16 +4291,11 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
-    // this should handle cases where the tensor is not contiguous in memory
-    // probaby just:
-    //
-    //     return tensor->ne[3]*tensor->nb[3]
-    //
-    // is enough, but just in case, adding the second part
-
-    return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type));
+    size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
+    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+        nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+    }
+    return nbytes;
 }
 
 size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
@@ -4569,36 +4749,51 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         enum   ggml_type      type,
         int                   n_dims,
         const int64_t       * ne,
-        void                * data) {
+        struct ggml_tensor  * view_src,
+        size_t                view_offs) {
 
     assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
 
-    size_t data_size = 0;
+    // find the base tensor and absolute offset
+    if (view_src != NULL && view_src->view_src != NULL) {
+        view_offs += view_src->view_offs;
+        view_src   = view_src->view_src;
+    }
 
-    if (data == NULL && !ctx->no_alloc) {
-        data_size += ggml_type_size(type)*(ne[0]/ggml_blck_size(type));
-        for (int i = 1; i < n_dims; i++) {
-            data_size *= ne[i];
-        }
+    size_t data_size = ggml_type_size(type)*(ne[0]/ggml_blck_size(type));
+    for (int i = 1; i < n_dims; i++) {
+        data_size *= ne[i];
     }
 
-    if (ctx->scratch.data != NULL && data == NULL) {
-        // allocate tensor data in the scratch buffer
-        if (ctx->scratch.offs + data_size > ctx->scratch.size) {
-            GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
-                    __func__, ctx->scratch.offs + data_size, ctx->scratch.size);
-            assert(false);
-            return NULL;
-        }
+    GGML_ASSERT(view_src == NULL || data_size + view_offs <= ggml_nbytes(view_src));
+
+    void * data = view_src != NULL ? view_src->data : NULL;
+    if (data != NULL) {
+        data = (char *) data + view_offs;
+    }
 
-        data = (char * const) ctx->scratch.data + ctx->scratch.offs;
+    size_t obj_alloc_size = 0;
 
-        ctx->scratch.offs += data_size;
+    if (view_src == NULL && ctx->no_alloc == false) {
+        if (ctx->scratch.data != NULL) {
+            // allocate tensor data in the scratch buffer
+            if (ctx->scratch.offs + data_size > ctx->scratch.size) {
+                GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
+                        __func__, ctx->scratch.offs + data_size, ctx->scratch.size);
+                assert(false);
+                return NULL;
+            }
 
-        data_size = 0;
+            data = (char * const) ctx->scratch.data + ctx->scratch.offs;
+
+            ctx->scratch.offs += data_size;
+        } else {
+            // allocate tensor data in the context's memory pool
+            obj_alloc_size = data_size;
+        }
     }
 
-    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + data_size);
+    struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size);
 
     // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here
 
@@ -4618,7 +4813,9 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
-        /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
+        /*.view_src     =*/ view_src,
+        /*.view_offs    =*/ view_offs,
+        /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
         /*.padding      =*/ { 0 },
@@ -4642,28 +4839,12 @@ static struct ggml_tensor * ggml_new_tensor_impl(
     return result;
 }
 
-static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
-    GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
-    assert(params_size <= GGML_MAX_OP_PARAMS);
-    memcpy(tensor->op_params, params, params_size);
-}
-
-static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
-    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
-    return ((const int32_t *)(tensor->op_params))[i];
-}
-
-static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
-    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
-    ((int32_t *)(tensor->op_params))[i] = value;
-}
-
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type      type,
         int                   n_dims,
         const int64_t       * ne) {
-    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL);
+    return ggml_new_tensor_impl(ctx, type, n_dims, ne, NULL, 0);
 }
 
 struct ggml_tensor * ggml_new_tensor_1d(
@@ -4728,7 +4909,23 @@ struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) {
 }
 
 struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) {
-    return ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL);
+    return ggml_new_tensor(ctx, src->type, src->n_dims, src->ne);
+}
+
+static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) {
+    GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings
+    assert(params_size <= GGML_MAX_OP_PARAMS);
+    memcpy(tensor->op_params, params, params_size);
+}
+
+static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    return ((const int32_t *)(tensor->op_params))[i];
+}
+
+static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) {
+    assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t));
+    ((int32_t *)(tensor->op_params))[i] = value;
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
@@ -5004,13 +5201,6 @@ struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * nam
     return tensor;
 }
 
-#ifdef __GNUC__
-#ifdef __MINGW32__
-__attribute__((gnu_format(printf, 2, 3)))
-#else
-__attribute__((format(printf, 2, 3)))
-#endif
-#endif
 struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -5021,14 +5211,13 @@ struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char *
 
 struct ggml_tensor * ggml_view_tensor(
         struct ggml_context * ctx,
-        const struct ggml_tensor * src) {
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data);
+        struct ggml_tensor  * src) {
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src, 0);
     ggml_format_name(result, "%s (view)", src->name);
 
-    result->nb[0] = src->nb[0];
-    result->nb[1] = src->nb[1];
-    result->nb[2] = src->nb[2];
-    result->nb[3] = src->nb[3];
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        result->nb[i] = src->nb[i];
+    }
 
     return result;
 }
@@ -5601,7 +5790,7 @@ struct ggml_tensor * ggml_repeat_back(
 
 // ggml_concat
 
-struct ggml_tensor* ggml_concat(
+struct ggml_tensor * ggml_concat(
     struct ggml_context* ctx,
     struct ggml_tensor* a,
     struct ggml_tensor* b) {
@@ -5868,7 +6057,8 @@ struct ggml_tensor * ggml_rms_norm_inplace(
 struct ggml_tensor * ggml_rms_norm_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
+        struct ggml_tensor  * b,
+        float  eps) {
     bool is_node = false;
 
     if (a->grad) {
@@ -5878,6 +6068,8 @@ struct ggml_tensor * ggml_rms_norm_back(
 
     struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
 
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
     result->op   = GGML_OP_RMS_NORM_BACK;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
@@ -6207,7 +6399,7 @@ struct ggml_tensor * ggml_reshape(
         //GGML_ASSERT(false);
     }
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, b->n_dims, b->ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
@@ -6231,7 +6423,7 @@ struct ggml_tensor * ggml_reshape_1d(
     }
 
     const int64_t ne[1] = { ne0 };
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
@@ -6256,7 +6448,7 @@ struct ggml_tensor * ggml_reshape_2d(
     }
 
     const int64_t ne[2] = { ne0, ne1 };
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
@@ -6282,7 +6474,7 @@ struct ggml_tensor * ggml_reshape_3d(
     }
 
     const int64_t ne[3] = { ne0, ne1, ne2 };
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
@@ -6292,7 +6484,6 @@ struct ggml_tensor * ggml_reshape_3d(
     return result;
 }
 
-
 struct ggml_tensor * ggml_reshape_4d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -6310,7 +6501,7 @@ struct ggml_tensor * ggml_reshape_4d(
     }
 
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a->data);
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
     ggml_format_name(result, "%s (reshaped)", a->name);
 
     result->op   = GGML_OP_RESHAPE;
@@ -6320,46 +6511,40 @@ struct ggml_tensor * ggml_reshape_4d(
     return result;
 }
 
-// ggml_view_1d
-
-static struct ggml_tensor * ggml_view_tensor_offset(
+static struct ggml_tensor * ggml_view_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_dims,
         const int64_t       * ne,
         size_t                offset) {
-    // don't calculate an offset from an unallocated tensor
-    void * data = NULL;
-    if (a->data != NULL) {
-        data = (char *) a->data + offset;
-    }
 
-    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, data);
+    bool is_node = false;
 
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
     ggml_format_name(result, "%s (view)", a->name);
 
     ggml_set_op_params(result, &offset, sizeof(offset));
 
+    result->op   = GGML_OP_VIEW;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
     return result;
 }
 
+// ggml_view_1d
+
 struct ggml_tensor * ggml_view_1d(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int64_t               ne0,
         size_t                offset) {
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 1, &ne0, offset);
-
-    result->op   = GGML_OP_VIEW;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset);
 
     return result;
 }
@@ -6374,24 +6559,14 @@ struct ggml_tensor * ggml_view_2d(
         size_t                nb1,
         size_t                offset) {
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, 1, 1 };
+    const int64_t ne[2] = { ne0, ne1 };
 
-    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 2, ne, offset);
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = result->nb[1]*ne1;
     result->nb[3] = result->nb[2];
 
-    result->op   = GGML_OP_VIEW;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
     return result;
 }
 
@@ -6407,24 +6582,14 @@ struct ggml_tensor * ggml_view_3d(
         size_t                nb2,
         size_t                offset) {
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, 1 };
+    const int64_t ne[3] = { ne0, ne1, ne2 };
 
-    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 3, ne, offset);
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = result->nb[2]*ne2;
 
-    result->op   = GGML_OP_VIEW;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
     return result;
 }
 
@@ -6442,24 +6607,14 @@ struct ggml_tensor * ggml_view_4d(
         size_t                nb3,
         size_t                offset) {
 
-    bool is_node = false;
-
-    if (a->grad) {
-        is_node = true;
-    }
-
-    const int64_t ne[GGML_MAX_DIMS] = { ne0, ne1, ne2, ne3 };
+    const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
 
-    struct ggml_tensor * result = ggml_view_tensor_offset(ctx, a, 4, ne, offset);
+    struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset);
 
     result->nb[1] = nb1;
     result->nb[2] = nb2;
     result->nb[3] = nb3;
 
-    result->op   = GGML_OP_VIEW;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
     return result;
 }
 
@@ -6646,7 +6801,7 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[] = { n_past, inplace ? 1 : 0 };
+    int32_t params[] = { n_past };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_INF;
@@ -6663,7 +6818,6 @@ struct ggml_tensor * ggml_diag_mask_inf(
     return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
 }
 
-
 struct ggml_tensor * ggml_diag_mask_inf_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
@@ -6686,7 +6840,7 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    int32_t params[] = { n_past, inplace ? 1 : 0 };
+    int32_t params[] = { n_past };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_DIAG_MASK_ZERO;
@@ -7475,6 +7629,8 @@ static struct ggml_tensor * ggml_add_rel_pos_impl(
     }
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+    ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
+
     result->op   = GGML_OP_ADD_REL_POS;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
@@ -9452,6 +9608,8 @@ static void ggml_compute_forward_div_f32(
 
 
 #ifdef GGML_USE_ACCELERATE
+            UNUSED(ggml_vec_div_f32);
+
             vDSP_vdiv(
                     (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
                     (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
@@ -10758,7 +10916,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
-    const float eps = 1e-6f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -11936,8 +12095,8 @@ static void ggml_compute_forward_diag_mask_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int  n_past  =       ((int32_t *) dst->op_params)[0];
-    const bool inplace = (bool)((int32_t *) dst->op_params)[1];
+    const int  n_past  = ((int32_t *) dst->op_params)[0];
+    const bool inplace = src0->data == dst->data;
 
     GGML_ASSERT(n_past >= 0);
 
@@ -12148,6 +12307,7 @@ static void ggml_compute_forward_soft_max_back_f32(
         // dx = J * dy
         // dxk = sum_i(Jki * dyi)
         // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
+        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
         // dxk = sum_i(-yk*yi * dyi) + yk*dyk
         // dxk = -yk * sum_i(yi * dyi) + yk*dyk
         // dxk = -yk * dot(y, dy) + yk*dyk
@@ -13938,7 +14098,7 @@ static void ggml_compute_forward_flash_attn_f32(
                 vvexpf(S, S, &Mup);
                 ggml_vec_sum_f32(Mup, &sum, S);
 #else
-                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
                 ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
                 for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
@@ -13948,9 +14108,13 @@ static void ggml_compute_forward_flash_attn_f32(
                         if (SS[j] == -INFINITY) {
                             SS[j] = 0.0f;
                         } else {
+#ifndef GGML_FLASH_ATTN_EXP_FP16
+                            const float val = expf(SS[j] - max);
+#else
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+#endif
                             sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
@@ -14528,7 +14692,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
                     vvexpf(SM, SM, &Mup);
                     ggml_vec_sum_f32(Mup, &sum, SM);
 #else
-                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
                     ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 
                     for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
@@ -14539,9 +14703,13 @@ static void ggml_compute_forward_flash_attn_back_f32(
                             if (SR[j] == -INFINITY) {
                                 SW[j] = 0.0f;
                             } else {
+#ifndef GGML_FLASH_ATTN_EXP_FP16
+                                const float val = expf(SR[j] - max);
+#else
                                 ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
                                 memcpy(&scvt[j], &s, sizeof(uint16_t));
                                 const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+#endif
                                 sump[j] += (ggml_float)val;
                                 SW[j] = val;
                             }
@@ -14987,11 +15155,8 @@ static void ggml_compute_forward_add_rel_pos_f32(
         const struct ggml_tensor * src1,
         const struct ggml_tensor * src2,
         struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(src0->nb[0] == dst->nb[0] && src0->nb[1] == dst->nb[1]
-             && src0->nb[2] == dst->nb[2] && src0->nb[3] == dst->nb[3]);
 
-    const bool inplace = dst->data == src0->data;
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
     if (!inplace && params->type == GGML_TASK_INIT) {
         memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
         return;
@@ -15282,6 +15447,8 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
+    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
+
     if (params->type == GGML_TASK_INIT) {
         if (ith == 0) {
             memset(sums, 0, sizeof(float) * (nth + nth * nc));
@@ -15293,7 +15460,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
         if (ith == 0) {
             float * dp = (float *) dst->data;
             ggml_vec_sum_f32(nth, dp, sums);
-            dp[0] *= -1.0f;
+            dp[0] *= -1.0f / (float) nr;
         }
         return;
     }
@@ -15310,7 +15477,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * st = (float *) params->wdata + nth + ith*nc;
+        float * st = ((float *) params->wdata) + nth + ith*nc;
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -15325,15 +15492,19 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
             float max = -INFINITY;
             ggml_vec_max_f32(nc, &max, s0);
 
-            uint16_t scvt;
+            uint16_t scvt; UNUSED(scvt);
             for (int i = 0; i < nc; i++) {
                 if (s0[i] == -INFINITY) {
                     st[i] = 0.0f;
                 } else {
-                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+#ifndef GGML_CROSS_ENTROPY_EXP_FP16
+                    const float s = s0[i] - max;
+                    const float val = expf(s);
+#else
                     ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
                     memcpy(&scvt, &s, sizeof(scvt));
                     const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+#endif
                     sum += (ggml_float)val;
                     st[i] = val;
                 }
@@ -15349,7 +15520,9 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
         ggml_vec_log_f32(nc, st, st);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        ggml_vec_sum_f32(nc, sums + ith, st);
+        float st_sum = 0;
+        ggml_vec_sum_f32(nc, &st_sum, st);
+        sums[ith] += st_sum;
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -15399,7 +15572,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         return;
     }
 
-    const float eps = 1e-9f;
+    const double eps = 1e-9;
 
     // TODO: handle transposed/permuted matrices
     const int64_t nc = src0->ne[0];
@@ -15418,7 +15591,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
         float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * sm  = (float *) params->wdata + ith*nc;
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -15427,54 +15599,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
             assert(!isnan(s1[i]));
         }
 #endif
-        // step by step explanation:
-        {
-            //float * sums = (float *) params->wdata;
-
-            // forward pass with annotated gradients from backward pass
-            // (built by going in reverse operation order, adding to gradients of current operation args)
-            // st0 = exp(s0-max(s0))                                                       grad[st0] = grad[st1]*(1.0 - eps)/sum
-                                                          // from softmax_back:            grad[s0]  = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
-            // ggml_vec_scale_f32(nc, st, sum);           // st1 = st0*/sum = softmax(s0)  grad[st1] = grad[st2]*(1.0 - eps)
-            // ggml_vec_scale_f32(nc, st, (1.0f - eps));  // st2 = st1*(1.0 - eps)         grad[st2] = grad[st3]
-            // ggml_vec_add1_f32(nc, st, st, eps);        // st3 = st2 + eps               grad[st3] = grad[st4]/st3
-            // ggml_vec_log_f32(nc, st, st);              // st4 = log(st3)                grad[st4] = grad[st5] * s1
-            // ggml_vec_mul_f32(nc, st, st, s1);          // st5 = st4 * s1                grad[st5] = grad[sums[ith]]
-            // ggml_vec_sum_f32(nc, sums + ith, st);      // sums[ith] = st5               grad[sums[ith]] = grad[cross_entropy_loss] = -grad[cel]
-
-            // substitute into grad[st1], because we can reuse softmax_back from this point on
-            // grad[st1] = -grad[cel]*s1*(1.0 - eps)/(eps + softmax(s0)*(1.0 - eps))
-            // postorder:
-            // grad[st1] := softmax(s0)
-            // grad[st1] := grad[st1]*(1.0 - eps)
-            // grad[st1] := grad[st1] + eps
-            // grad[st1] := s1 / grad[st1]
-            // grad[st1] := grad[st1]*(1.0-eps)*-grad[cel]
-
-            // src0 gradients by going through softmax_back
-            // grad[s0] = st1_k * (grad[st1]_k - dot(st1, grad[st1]))
-            //   from softmax_back:
-            //   dxk = yk * (dyk - dot(y, dy))
-            //   dot_y_dy := dot(y, dy)
-            //   dx := dy
-            //   dx := dx - dot_y_dy
-            //   dx := dx * y
-            //   postorder:
-            //   dot_st1_dst1 := dot(st1, grad[st1])
-            //   grad[s0] := grad[st1]
-            //   grad[s0] := grad[s0] - dot_st1_dst1
-            //   grad[s0] := grad[s0] * st1
-
-            // prepend postorder from grad[st1] directly using grad[s0] as memory location, as we will grad[s0] := grad[st1]
-            // sm           := softmax(s0)
-            // grad[s0]     := sm*(1.0 - eps)
-            // grad[s0]     := grad[s0] + eps
-            // grad[s0]     := s1 / grad[s0]
-            // grad[s0]     := grad[s0]*(1.0-eps)*-grad[cel]
-            // dot_st1_dst1 := dot(sm, grad[s0])
-            // grad[s0]     := grad[s0] - dot_st1_dst1
-            // grad[s0]     := grad[s0] * sm
-        }
 
         // soft_max
         ggml_float sum = 0.0;
@@ -15482,39 +15606,37 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
             float max = -INFINITY;
             ggml_vec_max_f32(nc, &max, s0);
 
-            uint16_t scvt;
+            uint16_t scvt; UNUSED(scvt);
             for (int i = 0; i < nc; i++) {
                 if (s0[i] == -INFINITY) {
-                    sm[i] = 0.0f;
+                    ds0[i] = 0.0f;
                 } else {
-                    // const float val = (s0[i] == -INFINITY) ? 0.0 : exp(s0[i] - max);
+#ifndef GGML_CROSS_ENTROPY_EXP_FP16
+                    const float s = s0[i] - max;
+                    const float val = expf(s);
+#else
                     ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
                     memcpy(&scvt, &s, sizeof(scvt));
                     const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
+#endif
                     sum += (ggml_float)val;
-                    sm[i] = val;
+                    ds0[i] = val;
                 }
             }
 
             assert(sum > 0.0);
-            sum = 1.0/sum;
+            sum = (1.0 - eps)/sum;
         }
 
-        float dot_st1_dst1 = 0;
-        ggml_vec_scale_f32(nc, sm, sum);
-        ggml_vec_cpy_f32  (nc, ds0, sm);
-        ggml_vec_scale_f32(nc, ds0, (1.0f - eps));
-        ggml_vec_add1_f32 (nc, ds0, ds0, eps);
-        ggml_vec_div_f32  (nc, ds0, s1, ds0);
-        ggml_vec_scale_f32(nc, ds0, -(1.0f - eps)*d[0]);
-        ggml_vec_dot_f32  (nc, &dot_st1_dst1, sm, ds0);
-        ggml_vec_acc1_f32 (nc, ds0, -dot_st1_dst1);
-        ggml_vec_mul_f32  (nc, ds0, ds0, sm);
+        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        ggml_vec_scale_f32(nc, ds0, sum);
+        ggml_vec_add1_f32(nc, ds0, ds0, eps);
+        ggml_vec_sub_f32(nc, ds0, ds0, s1);
+        ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
+
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
-            assert(!isnan(sm[i]));
-            assert(!isinf(sm[i]));
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }
@@ -16069,9 +16191,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
+                    float eps;
+                    memcpy(&eps, tensor->op_params, sizeof(float));
+
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
-                            ggml_rms_norm_back(ctx, src0, tensor->grad),
+                            ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
                             inplace);
                 }
             } break;
@@ -16839,9 +16964,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     return result;
 }
 
-struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
-    struct ggml_cgraph result = *gf;
-
+void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
     GGML_ASSERT(gf->n_nodes > 0);
 
     // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
@@ -16865,15 +16988,19 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
         }
     }
 
-    for (int i = gf->n_nodes - 1; i >= 0; i--) {
+    for (int i = 0; i < gf->n_nodes; i++) {
         struct ggml_tensor * node = gf->nodes[i];
 
         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_expand(&result, node->grad);
+            ggml_build_forward_expand(gb, node->grad);
         }
     }
+}
 
+struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
+    struct ggml_cgraph result = *gf;
+    ggml_build_backward_expand(ctx, gf, &result, keep);
     return result;
 }
 
@@ -17549,10 +17676,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
             case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
                 {
                     n_tasks = n_threads;
-
-                    size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
-
-                    work_size = MAX(work_size, cur);
                 } break;
             case GGML_OP_NONE:
                 {
@@ -18430,14 +18553,16 @@ static enum ggml_opt_result ggml_opt_adam(
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
-        struct ggml_cgraph * gb) {
+        struct ggml_cgraph * gb,
+        ggml_opt_callback callback,
+        void * callback_data) {
     GGML_ASSERT(ggml_is_scalar(f));
 
     // these will store the parameters we want to optimize
     struct ggml_tensor * ps[GGML_MAX_PARAMS];
 
     int np = 0;
-    int nx = 0;
+    int64_t nx = 0;
     for (int i = 0; i < gf->n_nodes; ++i) {
         if (gf->nodes[i]->is_param) {
             GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
@@ -18456,31 +18581,32 @@ static enum ggml_opt_result ggml_opt_adam(
     }
 
     // constants
-    const float sched = params.adam.sched;
-    const float decay = params.adam.decay * sched;
-    const float alpha = params.adam.alpha * sched;
+    float sched = params.adam.sched;
+    const float alpha = params.adam.alpha;
+    const float decay = params.adam.decay * alpha;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
+    const float gclip = params.adam.gclip;
+    const int decay_min_ndim = params.adam.decay_min_ndim;
 
-    float * x  = opt->adam.x->data;  // view of the parameters
-    float * g1 = opt->adam.g1->data; // gradient
-    float * g2 = opt->adam.g2->data; // gradient squared
     float * m  = opt->adam.m->data;  // first moment
     float * v  = opt->adam.v->data;  // second moment
-    float * mh = opt->adam.mh->data; // first moment hat
-    float * vh = opt->adam.vh->data; // second moment hat
 
     float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
-    // update view
-    ggml_opt_get_params(np, ps, x);
+    if (callback) {
+        callback(callback_data, &sched);
+    }
 
     // compute the function value
     ggml_graph_reset  (gf);
     ggml_set_f32      (f->grad, 1.0f);
 
-    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
+    struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
+    cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
+    ggml_graph_compute(gb, &cplan);
 
     opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
     opt->adam.fx_best = opt->adam.fx_prev;
@@ -18488,6 +18614,9 @@ static enum ggml_opt_result ggml_opt_adam(
         pf[opt->iter % params.past] = opt->adam.fx_prev;
     }
 
+    opt->loss_before = opt->adam.fx_prev;
+    opt->loss_after  = opt->adam.fx_prev;
+
     // initialize
     if (opt->just_initialized) {
         opt->adam.n_no_improvement = 0;
@@ -18520,50 +18649,55 @@ static enum ggml_opt_result ggml_opt_adam(
         UNUSED(t_start_cpu);
 
         {
-            // update the gradient
-            ggml_opt_get_grad(np, ps, g1);
-
-            // m_t = beta1*m_t-1 + (1 - beta1)*g_t
-            ggml_vec_scale_f32(nx, m, beta1);
-            ggml_vec_mad_f32  (nx, m, g1, 1.0f - beta1);
-
-            // g2 = g1^2
-            ggml_vec_sqr_f32  (nx, g2, g1);
-
-            // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
-            ggml_vec_scale_f32(nx, v, beta2);
-            ggml_vec_mad_f32  (nx, v, g2, 1.0f - beta2);
-
-            // m^hat = m_t / (1 - beta1^t)
-            // v^hat = v_t / (1 - beta2^t)
-            // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
-            // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
-            // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
-            // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
-            // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
-            ggml_vec_cpy_f32  (nx, mh, m);
-            ggml_vec_cpy_f32  (nx, vh, v);
-
-            ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
-            ggml_vec_scale_f32(nx, vh,  1.0f/(1.0f - powf(beta2, opt->iter)));
-
-            ggml_vec_sqrt_f32 (nx, vh, vh);
-            ggml_vec_acc1_f32 (nx, vh, eps);
-
-            ggml_vec_div_f32  (nx, mh, mh, vh);
-            ggml_vec_scale_f32(nx, x,  1.0f - decay);
-            ggml_vec_sub_f32  (nx, x,  x,  mh);
+            float gnorm = 1.0f;
+            if (gclip > 0.0f) {
+                // gradient clipping
+                ggml_float sum = 0.0;
+                for (int p = 0; p < np; ++p) {
+                    const int64_t ne = ggml_nelements(ps[p]);
+                    for (int64_t j = 0; j < ne; ++j) {
+                        float g = ggml_get_f32_1d(ps[p]->grad, j);
+                        sum += (ggml_float)(g*g);
+                    }
+                }
+                ggml_float norm = sqrt(sum);
+                if (norm > (ggml_float) gclip) {
+                    gnorm = (float) ((ggml_float) gclip / norm);
+                }
+            }
+            const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter));
+            const float beta2h =        1.0f/(1.0f - powf(beta2, opt->iter));
+            int64_t i = 0;
+            for (int p = 0; p < np; ++p) {
+                const int64_t ne = ggml_nelements(ps[p]);
+                const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
+                for (int64_t j = 0; j < ne; ++j) {
+                    float x = ggml_get_f32_1d(ps[p], j);
+                    float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
+                    m[i] = m[i]*beta1 +   g*(1.0f - beta1);
+                    v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
+                    float mh = m[i]*beta1h;
+                    float vh = v[i]*beta2h;
+                    vh = sqrtf(vh) + eps;
+                    x  = x*(1.0f - p_decay) - mh/vh;
+                    ggml_set_f32_1d(ps[p], j, x);
+                    ++i;
+                }
+            }
+        }
 
-            // update the parameters
-            ggml_opt_set_params(np, ps, x);
+        if (callback) {
+            callback(callback_data, &sched);
         }
 
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
+        ggml_graph_compute(gb, &cplan);
 
         const float fx = ggml_get_f32_1d(f, 0);
+        opt->loss_after = fx;
+
 
         // check convergence
         if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
@@ -18632,7 +18766,6 @@ struct ggml_lbfgs_iteration_data {
 };
 
 static enum ggml_opt_result linesearch_backtracking(
-        struct ggml_context * ctx,
         const struct ggml_opt_params * params,
         int nx,
         float * x,
@@ -18644,8 +18777,11 @@ static enum ggml_opt_result linesearch_backtracking(
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
         struct ggml_cgraph * gb,
+        struct ggml_cplan  * cplan,
         const int np,
-        struct ggml_tensor * ps[]) {
+        struct ggml_tensor * ps[],
+        ggml_opt_callback callback,
+        void * callback_data) {
     int count = 0;
 
     float width  = 0.0f;
@@ -18674,6 +18810,12 @@ static enum ggml_opt_result linesearch_backtracking(
     dgtest = params->lbfgs.ftol*dginit;
 
     while (true) {
+        if (callback) {
+            // LBFG-S does not support learning rate -> ignore learning schedule
+            float sched = 0;
+            callback(callback_data, &sched);
+        }
+
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);
 
@@ -18684,7 +18826,7 @@ static enum ggml_opt_result linesearch_backtracking(
             ggml_graph_reset  (gf);
             ggml_set_f32      (f->grad, 1.0f);
 
-            ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
+            ggml_graph_compute(gb, cplan);
 
             ggml_opt_get_grad(np, ps, g);
 
@@ -18718,7 +18860,6 @@ static enum ggml_opt_result linesearch_backtracking(
                     // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
                     return count;
                 }
-                return count;
             }
         }
 
@@ -18744,7 +18885,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
-        struct ggml_cgraph * gb) {
+        struct ggml_cgraph * gb,
+        ggml_opt_callback callback,
+        void * callback_data) {
     if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
         params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
         if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
@@ -18776,6 +18919,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         opt->iter = iter;
     }
 
+    struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
+    struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
+    cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
+
     float * x  = opt->lbfgs.x->data;  // current parameters
     float * xp = opt->lbfgs.xp->data; // previous parameters
     float * g  = opt->lbfgs.g->data;  // current gradient
@@ -18797,6 +18944,12 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s     = opt->lbfgs.lms->data;
     float * lm_y     = opt->lbfgs.lmy->data;
 
+    if (callback) {
+        // LBFG-S does not support learning rate -> ignore learning schedule
+        float sched = 0;
+        callback(callback_data, &sched);
+    }
+
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);
@@ -18804,11 +18957,14 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_graph_reset  (gf);
         ggml_set_f32      (f->grad, 1.0f);
 
-        ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
+        ggml_graph_compute(gb, &cplan);
 
         ggml_opt_get_grad(np, ps, g);
 
         fx = ggml_get_f32_1d(f, 0);
+
+        opt->loss_before = fx;
+        opt->loss_after  = fx;
     }
 
     // search direction = -gradient
@@ -18863,7 +19019,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -18873,6 +19029,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             return ls;
         }
 
+        opt->loss_after = fx;
+
         ggml_vec_norm_f32(nx, &xnorm, x);
         ggml_vec_norm_f32(nx, &gnorm, g);
 
@@ -18930,7 +19088,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         //     ys = y^t \cdot s    -> 1 / \rho.
         //     yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
         ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
         lm_ys[end[0]] = ys;
@@ -18993,13 +19151,15 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                     .adam = {
                         .n_iter = 10000,
                         .sched  = 1.000f,
-                        .decay  = 0.001f,
+                        .decay  = 0.0f,
+                        .decay_min_ndim = 2,
                         .alpha  = 0.001f,
                         .beta1  = 0.9f,
                         .beta2  = 0.999f,
                         .eps    = 1e-8f,
                         .eps_f  = 1e-5f,
                         .eps_g  = 1e-3f,
+                        .gclip  = 0.0f,
                     },
                 };
             } break;
@@ -19049,23 +19209,13 @@ GGML_API void ggml_opt_init(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                opt->adam.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
                 opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
                 opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
                 opt->adam.pf = params.past > 0
                     ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
                     : NULL;
-                ggml_set_zero(opt->adam.x);
-                ggml_set_zero(opt->adam.g1);
-                ggml_set_zero(opt->adam.g2);
                 ggml_set_zero(opt->adam.m);
                 ggml_set_zero(opt->adam.v);
-                ggml_set_zero(opt->adam.mh);
-                ggml_set_zero(opt->adam.vh);
                 if (opt->adam.pf) {
                     ggml_set_zero(opt->adam.pf);
                 }
@@ -19149,7 +19299,7 @@ enum ggml_opt_result ggml_opt_resume(
     *gf = ggml_build_forward (f);
     *gb = ggml_build_backward(ctx, gf, true);
 
-    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
 
 enum ggml_opt_result ggml_opt_resume_g(
@@ -19157,7 +19307,9 @@ enum ggml_opt_result ggml_opt_resume_g(
         struct ggml_opt_context * opt,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
-        struct ggml_cgraph * gb) {
+        struct ggml_cgraph * gb,
+        ggml_opt_callback callback,
+        void * callback_data) {
 
     // build forward + backward compute graphs
     enum ggml_opt_result result = GGML_OPT_OK;
@@ -19165,11 +19317,11 @@ enum ggml_opt_result ggml_opt_resume_g(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
             } break;
         case GGML_OPT_LBFGS:
             {
-                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
             } break;
     }
 
@@ -19624,7 +19776,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
     // read the kv pairs
     {
-        ctx->kv = GGML_ALIGNED_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv));
+        ctx->kv = malloc(ctx->header.n_kv * sizeof(struct gguf_kv));
 
         for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
             struct gguf_kv * kv = &ctx->kv[i];
@@ -19707,7 +19859,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
     // read the tensor infos
     {
-        ctx->infos = GGML_ALIGNED_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
+        ctx->infos = malloc(ctx->header.n_tensors * sizeof(struct gguf_tensor_info));
 
         for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
             struct gguf_tensor_info * info = &ctx->infos[i];
@@ -19908,7 +20060,7 @@ void gguf_free(struct gguf_context * ctx) {
             }
         }
 
-        GGML_ALIGNED_FREE(ctx->kv);
+        free(ctx->kv);
     }
 
     if (ctx->infos) {
@@ -19920,7 +20072,7 @@ void gguf_free(struct gguf_context * ctx) {
             }
         }
 
-        GGML_ALIGNED_FREE(ctx->infos);
+        free(ctx->infos);
     }
 
     GGML_ALIGNED_FREE(ctx);
diff --git a/ggml.h b/ggml.h
index 4ef3d525371fe6ae2621eabfc9e3cd349da4c8b8..c936823d661404434484bebf1c88ed0d2c822e48 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -479,6 +479,9 @@ extern "C" {
         int64_t perf_cycles;
         int64_t perf_time_us;
 
+        struct ggml_tensor * view_src;
+        size_t               view_offs;
+
         void * data;
 
         char name[GGML_MAX_NAME];
@@ -661,7 +664,7 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
 
     GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
-    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
+    GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
 
     GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
 
@@ -952,11 +955,11 @@ extern "C" {
 
     // a - x
     // b - dy
-    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 eps);
 
     // A: n columns, m rows
     // B: n columns, p rows  (i.e. we transpose it internally)
@@ -1612,7 +1615,8 @@ extern "C" {
             struct ggml_tensor  * tensor);
 
 
-    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
 
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
@@ -1677,6 +1681,8 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
 
+    typedef void (*ggml_opt_callback)(void * data, float * sched);
+
     // optimization parameters
     //
     //   see ggml.c (ggml_opt_default_params) for default values
@@ -1712,12 +1718,14 @@ extern "C" {
 
             float sched; // schedule multiplier (fixed, decay or warmup)
             float decay; // weight decay for AdamW, use 0.0f to disable
+            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
             float alpha; // learning rate
             float beta1;
             float beta2;
             float eps;   // epsilon for numerical stability
             float eps_f; // epsilon for convergence test
             float eps_g; // epsilon for convergence test
+            float gclip; // gradient clipping
         } adam;
 
         // LBFGS parameters
@@ -1745,14 +1753,12 @@ extern "C" {
 
         bool just_initialized;
 
+        float loss_before;
+        float loss_after;
+
         struct {
-            struct ggml_tensor * x;  // view of the parameters
-            struct ggml_tensor * g1; // gradient
-            struct ggml_tensor * g2; // gradient squared
             struct ggml_tensor * m;  // first moment
             struct ggml_tensor * v;  // second moment
-            struct ggml_tensor * mh; // first moment hat
-            struct ggml_tensor * vh; // second moment hat
             struct ggml_tensor * pf; // past function values
             float fx_best;
             float fx_prev;
@@ -1789,10 +1795,10 @@ extern "C" {
 
     // initialize optimizer context
     GGML_API void ggml_opt_init(
-            struct ggml_context * ctx,
+            struct ggml_context     * ctx,
             struct ggml_opt_context * opt,
-            struct ggml_opt_params params,
-            int64_t nx);
+            struct ggml_opt_params    params,
+            int64_t                   nx);
 
     // continue optimizing the function defined by the tensor f
     GGML_API enum ggml_opt_result ggml_opt_resume(
@@ -1806,7 +1812,9 @@ extern "C" {
             struct ggml_opt_context * opt,
             struct ggml_tensor * f,
             struct ggml_cgraph * gf,
-            struct ggml_cgraph * gb);
+            struct ggml_cgraph * gb,
+            ggml_opt_callback callback,
+            void * callback_data);
 
     //
     // quantization