void * extra; // extra things e.g. for ggml-cuda.cu
- char padding[8];
+ char padding[4];
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
void * abort_callback_data;
};
+ // next prime after GGML_MAX_NODES
+ // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+ // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+ #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
// computation graph
struct ggml_cgraph {
int n_nodes;
struct ggml_tensor * grads[GGML_MAX_NODES];
struct ggml_tensor * leafs[GGML_MAX_NODES];
+ void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
// performance
int perf_runs;
int64_t perf_cycles;
GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a,
+ float eps);
GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a,
+ float eps);
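// Usage sketch for the new signature (ctx0/inp names and the eps value are illustrative, not part of
// this change):
//
//     struct ggml_tensor * cur = ggml_rms_norm(ctx0, inp, 1e-6f);
//     cur = ggml_rms_norm_inplace(ctx0, cur, 1e-6f);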
// a - x
// b - dy
+ // TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
}
}
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- const float eps = 1e-6f;
-
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += WARP_SIZE) {
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
+ uint16_t q16[8];
+ const uint8_t * q4 = (const uint8_t *)q16;
+
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
- const uint8_t * ql2 = ql1 + 64;
const uint8_t * qh = x[i].qh + l0;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
+ const uint16_t * q1 = (const uint16_t *)ql1;
+ const uint16_t * q2 = q1 + 32;
+ q16[0] = q1[0] & 0x0f0f;
+ q16[1] = q1[8] & 0x0f0f;
+ q16[2] = (q1[0] >> 4) & 0x0f0f;
+ q16[3] = (q1[8] >> 4) & 0x0f0f;
+ q16[4] = q2[0] & 0x0f0f;
+ q16[5] = q2[8] & 0x0f0f;
+ q16[6] = (q2[0] >> 4) & 0x0f0f;
+ q16[7] = (q2[8] >> 4) & 0x0f0f;
for (int l = 0; l < n; ++l) {
- sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
- + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
- sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
- + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
- sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
- + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
- sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
- + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+ sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+ sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+ sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+ sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
- const int bq8_offset = QR4_K * (iqs / QI8_1); // 0, 2, 4, 6
-
float sumf_d = 0.0f;
float sumf_m = 0.0f;
+#ifndef GGML_QKK_64
+
+ // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+ const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+
const float d = bq4_K->d;
const float dmin = bq4_K->dmin;
- const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
+ // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+ // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+ // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+ // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+ const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+ const int v1 = q4[0];
+ const int v2 = q4[4];
const uint16_t * scales = (const uint16_t *)bq4_K->scales;
uint16_t aux[2];
for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
+ const int * q8 = (const int *)bq8i->qs + (iqs%4);
+ const int ui1 = q8[0];
+ const int ui2 = q8[4];
- const int vi = (v >> (4*i)) & 0x0F0F0F0F;
+ const int vi1 = (v1 >> (4*i)) & 0x0F0F0F0F;
+ const int vi2 = (v2 >> (4*i)) & 0x0F0F0F0F;
- sumf_d += d8i * (__dp4a(vi, ui, 0) * sc[i]); // SIMD dot product
- sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+ const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+
+ sumf_d += d8i * (dot1 * sc[i]);
+ sumf_m += d8i * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
}
return d*sumf_d - dmin*sumf_m;
+
+#else
+
+ uint16_t aux16[2];
+ const uint8_t * s = (const uint8_t *)aux16;
+
+ const uint16_t * a = (const uint16_t *)bq4_K->scales;
+ aux16[0] = a[0] & 0x0f0f;
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+ const float dall = bq4_K->d[0];
+ const float dmin = bq4_K->d[1];
+
+ const float d8_1 = bq8_1[0].d;
+ const float d8_2 = bq8_1[1].d;
+
+ const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+ const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+ const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+ const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+ const int * q4 = (const int *)bq4_K->qs + iqs;
+ const int v1 = q4[0];
+ const int v2 = q4[4];
+
+ const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+ const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+ const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+ const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+ sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+ sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+ return dall * sumf_d - dmin * sumf_m;
+
+#endif
+
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q5_K * bq5_K = (const block_q5_K *) vbq;
- const int bq8_offset = QR5_K * (iqs / QI8_1);
+#ifndef GGML_QKK_64
+
+ const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
+ const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
+ const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
float sumf_d = 0.0f;
float sumf_m = 0.0f;
const float d = bq5_K->d;
const float dmin = bq5_K->dmin;
- const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
+ const int vl1 = ql[0];
+ const int vl2 = ql[4];
- const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
+ const int vh1 = qh[0] >> bq8_offset;
+ const int vh2 = qh[4] >> bq8_offset;
- for (int i = 0; i < QR5_K; ++i) {
- const int isc = bq8_offset + i;
+ const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
- uint8_t sc, m;
- get_scale_min_k4(isc, bq5_K->scales, sc, m);
+ for (int i = 0; i < QR5_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
- const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
+ const int * q8 = (const int *)bq8i->qs + (iqs%4);
+ const int ui1 = q8[0];
+ const int ui2 = q8[4];
- const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+ const int vil1 = (vl1 >> (4*i)) & 0x0F0F0F0F;
+ const int vil2 = (vl2 >> (4*i)) & 0x0F0F0F0F;
+
+ const int vih1 = ((vh1 >> i) << 4) & 0x10101010;
+ const int vih2 = ((vh2 >> i) << 4) & 0x10101010;
+
+ const int vi1 = vil1 | vih1;
+ const int vi2 = vil2 | vih2;
- const int vih = ((vh >> i) << 4) & 0x10101010;
+ const int dot1 = __dp4a(vi2, ui2, __dp4a(vi1, ui1, 0)); // SIMD dot product
+ const int dot2 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
- const int vi = vil | vih;
+ sumf_d += d8i * (dot1 * sc[i]);
+ sumf_m += d8i * (dot2 * m[i]);
- sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
- sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values
}
return d*sumf_d - dmin*sumf_m;
+
+#else
+
+ const int8_t * s = bq5_K->scales;
+
+ const float d = bq5_K->d;
+
+ const float d8_1 = bq8_1[0].d;
+ const float d8_2 = bq8_1[1].d;
+
+ const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+ const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+ const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+ const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+ const int * ql = (const int *)bq5_K->qs + iqs;
+ const int vl1 = ql[0];
+ const int vl2 = ql[4];
+
+ const int step = 4 * iqs; // 0, 4, 8, 12
+ const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+ const int in = step%8; // 0, 4, 0, 4
+ const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+ const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+ const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+ const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+ const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+ const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+ + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+ return d * sumf_d;
+
+#endif
+
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
- rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+ rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
+ // Note: we use QI4_K/2 instead of QI4_K to make the dot product template require 4 groups of quants to be processed per
+ // kernel call instead of 2. This results in better performance because the cost of computing the k-quant scales
+ // is better amortized.
+ mul_mat_vec_q<QK_K, QI4_K/2, block_q4_K, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
- mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
+ // Note: we use QI5_K/2 instead of QI5_K to make the dot product template require 4 groups of quants to be processed per
+ // kernel call instead of 2. This results in better performance because the cost of computing the k-quant scales
+ // is better amortized.
+ mul_mat_vec_q<QK_K, QI5_K/2, block_q5_K, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
// compute
- rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+ rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
(void) src1;
(void) dst;
// get data from the device into host memory
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
+// try to find operations that can be run concurrently in the graph
+// you should run it again if the topology of your graph changes
+void ggml_metal_graph_find_concurrency(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+
+// returns true if the graph has been optimized for concurrent dispatch
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx);
+
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
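// Call-order sketch for the new concurrency API (assumes an initialized ctx and a built graph gf; names
// are illustrative, not part of this change):
//
//     ggml_metal_graph_find_concurrency(ctx, &gf);   // re-run whenever the graph topology changes
//     if (ggml_metal_if_optimized(ctx)) {
//         // eligible nodes will be dispatched concurrently
//     }
//     ggml_metal_graph_compute(ctx, &gf);            // uses the concurrency list when available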
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+ int concur_list[GGML_MAX_NODES];
+ int concur_list_len;
+
// custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
id<MTLFunction> function_##name; \
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
+ ctx->concur_list_len = 0;
// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
ctx->n_cb = n_cb;
}
+bool ggml_metal_if_optimized(struct ggml_metal_context * ctx) {
+ if (ctx->concur_list_len) {
+ return true;
+ }
+ return false;
+}
+
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
}
+void ggml_metal_graph_find_concurrency(
+ struct ggml_metal_context * ctx,
+ struct ggml_cgraph * gf) {
+ int search_depth = gf->n_nodes; // we only look for concurrency within this range to avoid wasting too much time
+ int nodes_unused[GGML_MAX_NODES];
+
+ for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
+ for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ ctx->concur_list_len = 0;
+
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start in the nodes_unused array have been sorted and stored back into ctx->concur_list
+ int level_pos = 0; // in ctx->concur_list, the last layer (level) ends at level_pos
+
+ while (n_left > 0) {
+ // number of nodes at a layer (that can be issued concurrently)
+ int concurrency = 0;
+ for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
+ if (nodes_unused[i]) {
+ // if the requirements for gf->nodes[i] are satisfied
+ int exe_flag=1;
+ // scan all srcs
+ for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
+ struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
+ if (src_cur) {
+ // if it is a leaf node, the requirement is satisfied.
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+
+ // otherwise this src must be the output of a previously scheduled node.
+ int is_found = 0;
+ // scan 2*search_depth back because barriers have been inserted.
+ for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ }
+ if (is_found == 0) {exe_flag = 0; break;}
+ }
+ }
+ if (exe_flag) {
+ // check if nodes[i]'s data will be overwritten by a node before nodes[i].
+ // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
+ int64_t data_start = (int64_t) gf->nodes[i]->data;
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ for (int j = n_start; j < i; j++) {
+ if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
+ && gf->nodes[j]->op != GGML_OP_VIEW \
+ && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
+ && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+ if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
+ ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
+ continue;
+ } else {
+ exe_flag = 0;
+ }
+ }
+ }
+ }
+ if (exe_flag) {
+ ctx->concur_list[level_pos + concurrency] = i;
+ nodes_unused[i] = 0;
+ concurrency++;
+ ctx->concur_list_len++;
+ }
+ }
+ }
+ n_left -= concurrency;
+ // add a barrier (-1) to separate this layer from the next
+ ctx->concur_list[level_pos + concurrency] = -1;
+ ctx->concur_list_len++;
+ // skip over nodes that have already been sorted
+ while (!nodes_unused[n_start]) {n_start++;}
+ level_pos += concurrency + 1;
+ }
+
+ if (ctx->concur_list_len > GGML_MAX_NODES) {
+ fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
+ }
+}
+
void ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
metal_printf("%s: evaluating graph\n", __func__);
+ // if ctx->concur_list has been populated, dispatch concurrently
+ // otherwise fall back to serial dispatch
+ MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
+
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+
+ const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
+ edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;
+
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
- const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
+ const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
dispatch_async(queue, ^{
size_t offs_src0 = 0;
id<MTLComputeCommandEncoder> encoder = nil;
- const int node_start = (cb_idx + 0) * n_nodes_per_cb;
- const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+ const int node_start = (cb_idx + 0) * n_nodes_per_cb;
+ const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
+
+ for (int ind = node_start; ind < node_end; ++ind) {
+ const int i = has_concur ? ctx->concur_list[ind] : ind;
+
+ if (i == -1) {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ continue;
+ }
+ [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+ continue;
+ }
- for (int i = node_start; i < node_end; ++i) {
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
case GGML_OP_ADD:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
case GGML_OP_MUL:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
if (ggml_nelements(src1) == ne10) {
case GGML_OP_SCALE:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float scale = *(const float *) src1->data;
case GGML_UNARY_OP_SILU:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_silu];
case GGML_UNARY_OP_RELU:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_relu];
case GGML_UNARY_OP_GELU:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
[encoder setComputePipelineState:ctx->pipeline_gelu];
case GGML_OP_SOFT_MAX:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *)(dst->op_params))[0];
}
} else {
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
int nth0 = 32;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
switch (src0->type) {
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
- const float eps = 1e-6f;
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 512;
case GGML_OP_NORM:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const float eps = 1e-5f;
case GGML_OP_ALIBI:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
GGML_ASSERT((src0t == GGML_TYPE_F32));
case GGML_OP_ROPE:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int n_past = ((int32_t *) dst->op_params)[0];
case GGML_OP_CONT:
{
if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoder];
+ encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
}
const int nth = 32;
}
}
-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculating the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
- float4 acc = 0.f;
- device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
- for (int i = 0; i < 16; i+=2) {
- acc[0] += yl[i] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
- acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ float2 acc = 0.f;
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
}
- return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+ return d * (sumy * -8.f + acc[0] + acc[1]);
}
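// Reading aid (not part of this change): each masked value above is the 4-bit quant shifted left by
// 0, 4, 8 or 12 bits; the matching yl entries were pre-divided by 1, 16, 256 and 4096 in the caller,
// so every product comes out as plain q*y without explicit shifts, and the "sumy * -8.f" term applies
// the q4_0 offset of -8.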
-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculating the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
- float4 acc = 0.f;
- device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
- for (int i = 0; i < 16; i+=2) {
- acc[0] += yl[i] * (qs[i / 2] & 0x000F);
- acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
- acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
- acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+ float2 acc = 0.f;
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ + yl[i + 1] * (qs[i / 2] & 0x0F00);
+ acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+ + yl[i + 9] * (qs[i / 2] & 0xF000);
}
- return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+ return d * (acc[0] + acc[1]) + sumy * m;
}
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template<typename block_q_type>
+// Note: This is a template, but strictly speaking it only applies to
+//       quantizations where the block size is 32. It also does not
+//       guard against the number of rows not being divisible by
+//       N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
uint2 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
- device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+ const int first_row = (r0 * nsg + sgitg) * nr;
+ device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
device const float * y = (device const float *) src1 + r1*ne10;
- float4 y_curr[8]; // src1 vector cache
- float sumf[N_DST]={0.f}, all_sum;
- thread float * yl=(thread float *)y_curr;
+ float yl[16]; // src1 vector cache
+ float sumf[nr]={0.f};
- // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
- float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
- }
+ const int ix = tiisg/2;
+ const int il = 8*(tiisg%2);
- for (int row = 0; row < N_DST; row++) {
- sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
- }
- }
+ device const float * yb = y + ix * QK4_0 + il;
- // from now loads two rows every time and 16 blocks per row
- int ir = tiisg / (N_SIMDWIDTH / 2);
- int ib = tiisg % (N_SIMDWIDTH / 2);
- for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
- int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
+ // each thread in a SIMD group deals with half a block.
+ for (int ib = ix; ib < nb; ib += nw/2) {
float sumy = 0;
- for (int i = 0; i < QK4_0 / 4; i++) {
- y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
- sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+ for (int i = 0; i < 8; i += 2) {
+ sumy += yb[i] + yb[i+1];
+ yl[i+0] = yb[i+ 0];
+ yl[i+1] = yb[i+ 1]/256.f;
+ sumy += yb[i+16] + yb[i+17];
+ yl[i+8] = yb[i+16]/16.f;
+ yl[i+9] = yb[i+17]/4096.f;
}
- for (int row = 0; row < N_DST; row+=2) {
- if (nb_start + ib < nb) {
- sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
- }
+ for (int row = 0; row < nr; row++) {
+ sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
}
+
+ yb += QK4_0 * 16;
}
- for (int row = 0; row < N_DST; ++row) {
- all_sum = simd_sum(sumf[row]);
- if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
- dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+ for (int row = 0; row < nr; ++row) {
+ const float tot = simd_sum(sumf[row]);
+ if (tiisg == 0 && first_row + row < ne01) {
+ dst[r1*ne0 + first_row + row] = tot;
}
}
}
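// Reading aid (not part of this change): each SIMD lane now processes half a block per iteration
// (ix = tiisg/2 picks the block, il = 8*(tiisg%2) picks which half), the y values are cached in yl with
// the pre-scaling expected by block_q_n_dot_y, and the per-row partial sums are reduced with simd_sum().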
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_q4_1_f32(
uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_f16_f32(
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
+static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return
+ tensor->nb[0] == GGML_TYPE_SIZE[tensor->type] &&
+ tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
+ tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
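// Reading aid (not part of this change): compared to ggml_is_contiguous(), the nb[1] check is dropped,
// so the tensor may be padded or strided along dim 1 while dims 0, 2 and 3 remain tightly packed.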
+
bool ggml_is_permuted(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ float eps,
bool inplace) {
bool is_node = false;
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- // TODO: maybe store epsilon here?
+ ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, false);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, false);
}
struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_rms_norm_impl(ctx, a, true);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_rms_norm_impl(ctx, a, eps, true);
}
struct ggml_tensor * ggml_rms_norm_back(
}
//struct ggml_tensor * result = ggml_dup_tensor(ctx, q);
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, q->ne);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, q->n_dims, q->ne);
+
+ int32_t t = masked ? 1 : 0;
+ ggml_set_op_params(result, &t, sizeof(t));
result->op = GGML_OP_FLASH_ATTN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
- result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
return result;
}
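// Note (reading aid, not part of this change): the masked flag now lives in op_params instead of an
// extra src tensor, and is read back at compute time with ggml_get_op_params_i32(tensor, 0), as in the
// GGML_OP_FLASH_ATTN case further down in this diff.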
}
//struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, a->ne);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, a->n_dims, a->ne);
result->op = GGML_OP_FLASH_FF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+ int32_t masked_i = masked ? 1 : 0;
+ ggml_set_op_params(result, &masked_i, sizeof(masked_i));
+
result->op = GGML_OP_FLASH_ATTN_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = q;
result->src[1] = k;
result->src[2] = v;
result->src[3] = d;
- result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
return result;
}
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
const struct ggml_tensor * src0,
const struct ggml_tensor * grad,
struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_is_contiguous(grad));
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, grad));
GGML_TENSOR_UNARY_OP_LOCALS;
- const float eps = 1e-6f; // TODO: make this a parameter
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
} break;
case GGML_OP_FLASH_ATTN:
{
- const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+ const int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1);
const bool masked = t != 0;
ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
- int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
+ int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1);
bool masked = t != 0;
ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
{
struct ggml_tensor * flash_grad = NULL;
if (src0->grad || src1->grad || tensor->src[2]->grad) {
- int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
+ int32_t t = ggml_get_op_params_i32(tensor, 0);
GGML_ASSERT(t == 0 || t == 1);
bool masked = t != 0;
flash_grad =
}
}
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HASHTABLE_SIZE is too small");
+
+static size_t hash(void * p) {
+ return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+ size_t h = hash(p);
+
+ // linear probing
+ size_t i = h;
+ while (hash_table[i] != NULL && hash_table[i] != p) {
+ i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+ if (i == h) {
+ // hash table is full
+ GGML_ASSERT(false);
+ }
+ }
+
+ if (hash_table[i] == p) {
+ return true;
+ }
+
+ // insert
+ hash_table[i] = p;
+ return false;
+}
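// Behavior sketch (illustrative, not part of this change): hash_insert() returns true if p was already
// present and false after inserting it, which is exactly the "already visited?" test used below:
//
//     void * table[GGML_GRAPH_HASHTABLE_SIZE] = { NULL };
//     hash_insert(table, node);   // false: first visit, pointer is recorded
//     hash_insert(table, node);   // true:  pointer was already recorded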
+
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
if (node->grad == NULL) {
// this usually happens when we generate intermediate nodes from constants in the backward pass
}
// check if already visited
- for (int i = 0; i < cgraph->n_nodes; i++) {
- if (cgraph->nodes[i] == node) {
- return;
- }
- }
-
- for (int i = 0; i < cgraph->n_leafs; i++) {
- if (cgraph->leafs[i] == node) {
- return;
- }
+ if (hash_insert(cgraph->visited_hash_table, node)) {
+ return;
}
for (int i = 0; i < GGML_MAX_SRC; ++i) {
/*.nodes =*/ { NULL },
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
+ /*.hash_table =*/ { NULL },
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
if (node->is_param) {
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
- ggml_build_forward_impl(&result, node->grad, true);
+ ggml_build_forward_expand(&result, node->grad);
}
}
ggml_set_param(ctx0, x[i]);
}
- struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
+ struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
}