struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = { };
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
struct ggml_context * ctx_eval = NULL;
struct ggml_cgraph gfi = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
- gfi.n_threads = n_threads;
// allocate work context
// needed during ggml_graph_compute() to allocate a work tensor
- static size_t buf_size = gfi.work_size; // TODO
+ static size_t buf_size = 128ull*1024*1024; // TODO
static void * buf = malloc(buf_size);
struct ggml_init_params params = {
struct ggml_tensor * input = ggml_graph_get_tensor(&gfi, "input");
memcpy(input->data, digit.data(), ggml_nbytes(input));
- ggml_graph_compute(ctx_work, &gfi);
+ ggml_graph_compute_with_ctx(ctx_work, &gfi, n_threads);
const float * probs_data = ggml_get_data_f32(ggml_graph_get_tensor(&gfi, "probs"));
struct ggml_context * ctx_eval = NULL;
struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
- gf.n_threads = 1;
// allocate work context
- static size_t buf_size = gf.work_size; // TODO
+ static size_t buf_size = 128ull*1024*1024; // TODO
static void * buf = malloc(buf_size);
struct ggml_init_params params = {
- .mem_size = buf_size,
- .mem_buffer = buf,
- .no_alloc = false,
+ /*.mem_size =*/ buf_size,
+ /*.mem_buffer =*/ buf,
+ /*.no_alloc =*/ false,
};
struct ggml_context * ctx_work = ggml_init(params);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * input = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_input);
memcpy(input->data, digit.data(), ggml_nbytes(input));
// build / export / run the computation graph
ggml_build_forward_expand(&gf, probs);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//ggml_graph_print (&gf);
ggml_graph_dump_dot(&gf, NULL, "mnist.dot");
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
// std::cout << "Qcur" << std::endl;
// print_tensor(Qcur);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N * ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
// std::cout << "Qcur" << std::endl;
// print_tensor(Qcur);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
// run the computation
ggml_build_forward_expand(&gf, inpL);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
printf("%s: top_p = %.3f\n", __func__, params.top_p);
printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n);
printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty);
-
+
int n_past = 0;
int64_t t_sample_us = 0;
std::vector<int32_t> last_n_tokens(model.hparams.n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-
+
// tokenize the prompt
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(vocab, params.prompt);
embd.push_back(id);
last_n_tokens.erase(last_n_tokens.begin());
- last_n_tokens.push_back(id);
+ last_n_tokens.push_back(id);
} else {
// if here, it means we are still processing the input prompt
for (int k = i; k < embd_inp.size(); k++) {
embd.push_back(embd_inp[k]);
last_n_tokens.erase(last_n_tokens.begin());
- last_n_tokens.push_back(embd_inp[k]);
+ last_n_tokens.push_back(embd_inp[k]);
if (embd.size() >= params.n_batch) {
break;
// run the computation
{
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
ggml_build_forward_expand(&gf, cur);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//ggml_graph_print(&gf);
}
// pre-compute cross-attention memory
{
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
// TODO: hack to disconnect the encoded features from the previous graph
cur->op = GGML_OP_NONE;
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v));
}
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
//ggml_graph_print(&gf);
}
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, tokens, N*ggml_element_size(embd));
// run the computation
{
ggml_build_forward_expand(&gf, logits);
- ggml_graph_compute (ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
}
// extract logits for all N tokens
struct ggml_cgraph gf = ggml_build_forward(c);
- gf.n_threads = n_threads;
-
double tsum = 0.0;
// heat-up
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
for (int i = 0; i < n_max; ++i) {
const int64_t t0 = ggml_time_us();
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const int64_t t1 = ggml_time_us();
// ggml_set_f32(a, 3.0f);
// ggml_set_f32(b, 4.0f);
//
-// ggml_graph_compute(ctx0, &gf);
+// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
//
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
struct ggml_tensor * src1;
struct ggml_tensor * opt[GGML_MAX_OPT];
- // thread scheduling
- int n_tasks;
-
// performance
int perf_runs;
int64_t perf_cycles;
void * extra; // extra things e.g. for ggml-cuda.cu
- char padding[4];
+ char padding[8];
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+ // the compute plan that needs to be prepared for ggml_graph_compute()
+ // since https://github.com/ggerganov/ggml/issues/287
+ struct ggml_cplan {
+ size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+ uint8_t * work_data; // work buffer, to be allocated by the caller before calling `ggml_graph_compute()`
+
+ int n_threads;
+
+ // per-node task counts (`n_tasks`), 1:1 mapping to cgraph nodes
+ int n_tasks[GGML_MAX_NODES];
+ };
+
// computation graph
struct ggml_cgraph {
int n_nodes;
int n_leafs;
- int n_threads;
-
- size_t work_size;
- struct ggml_tensor * work;
struct ggml_tensor * nodes[GGML_MAX_NODES];
struct ggml_tensor * grads[GGML_MAX_NODES];
GGML_API void ggml_set_param(
struct ggml_context * ctx,
- struct ggml_tensor * tensor);
+ struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
- GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+ // ggml_graph_plan() has to be called before ggml_graph_compute()
+ // when plan.work_size > 0, caller must allocate memory for plan.work_data
+ GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+ GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+
+ // same as ggml_graph_compute(), but the work data is allocated as part of the context
+ // note: the drawback of this API is that you must ensure the context has enough memory for the work data
+ GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
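For reference, a minimal sketch of the two-step flow described by the declarations above, with the caller owning the work buffer. The helper name `compute_graph_with_plan`, the output tensor `f`, and the use of plain malloc/free are illustrative assumptions, not part of this change:

    #include <stdlib.h>  // malloc/free for the caller-owned work buffer (assumption: libc is acceptable here)

    // hypothetical helper: plan, allocate the work buffer, compute, clean up
    static void compute_graph_with_plan(struct ggml_tensor * f, int n_threads) {
        struct ggml_cgraph gf   = ggml_build_forward(f);
        struct ggml_cplan  plan = ggml_graph_plan(&gf, n_threads);

        if (plan.work_size > 0) {
            plan.work_data = (uint8_t *) malloc(plan.work_size); // required whenever work_size > 0
        }

        ggml_graph_compute(&gf, &plan);

        free(plan.work_data); // free(NULL) is a no-op when no work buffer was needed
    }

When the extra bookkeeping is not wanted, ggml_graph_compute_with_ctx() above performs the same steps but carves the work buffer out of the ggml context instead.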
#endif //GGML_CUDA_DMMV_F16
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
-typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
+typedef void (*to_fp32_cuda_t)(const void * __restrict__ x, float * __restrict__ y, int k, cudaStream_t stream);
+typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
typedef void (*ggml_cuda_op_t)(
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
-typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs);
//================================= k-quants
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
#define WARP_SIZE 32
+#define MATRIX_ROW_PADDING 256 // the last row of quantized matrices is padded to a multiple of this to avoid out-of-bounds memory accesses
#define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256
//================================== k-quants
-static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float * __restrict__ yy) {
const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
}
-static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, float * __restrict__ yy) {
const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
}
#endif
-static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x;
#endif
}
-static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx;
const int i = blockIdx.x;
#endif
}
-static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, float * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x;
#endif
}
-static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
}
}
-static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
}
}
-static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
}
}
-static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
+static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
}
}
-static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
+static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
v.y = x[ib + iqs + 1];
}
-static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int ndata, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
block_q8_1 * y = (block_q8_1 *) vy;
- const int ib = i / QK8_0; // block index
- const int iqs = i % QK8_0; // quant index
+ const int ib = i / QK8_1; // block index
+ const int iqs = i % QK8_1; // quant index
- const float xi = x[i];
+ const float xi = i < ndata ? x[i] : 0.0f;
float amax = fabsf(xi);
float sum = xi;
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_block(const void * vx, float * y, const int k) {
+static __global__ void dequantize_block(const void * __restrict__ vx, float * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
if (i >= k) {
y[iybs + iqs + y_offset] = v.y;
}
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
#endif // __CUDA_ARCH__ >= 600
}
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
#endif // __CUDA_ARCH__ >= 600
}
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
#endif // __CUDA_ARCH__ >= 600
}
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
#endif // __CUDA_ARCH__ >= 600
}
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
}
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row >= nrows) {
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
+static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.y*blockDim.y + threadIdx.y;
}
}
-static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
}
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
- const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int channel_stride_x) {
const half * x = (const half *) vx;
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
-static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
- quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
+ quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, ndata, k);
}
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0;
- // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6.
- // However, they have bad performance with Pascal cards.
- // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q.
- const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+ const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
#endif
if (use_mul_mat_vec_q) {
+ int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;
+ padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;
size_t as;
- void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
- quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
+ void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as);
+ quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, padded_row_size, cudaStream_main);
switch (src0->type) {
case GGML_TYPE_Q4_0:
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
int nrows = ggml_nrows(tensor);
+
+ const int64_t ne0 = tensor->ne[0];
+
const size_t nb1 = tensor->nb[1];
+
ggml_backend backend = tensor->backend;
struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra));
int64_t nrows_split = row_high - row_low;
const size_t offset_split = row_low*nb1;
- const size_t size = ggml_nbytes_split(tensor, nrows_split);
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 256 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += (MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING)
+ * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+ }
- void * buf;
+ char * buf;
CUDA_CHECK(cudaMalloc(&buf, size));
- void * buf_host = (char*)data + offset_split;
+ char * buf_host = (char*)data + offset_split;
+
+ // set padding to 0 to avoid possible NaN values
+ if (size > original_size) {
+ CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
+ }
+
cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
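To make the padding arithmetic above concrete, a small worked instance (the value of ne00 is illustrative):

    // round a row of ne00 = 4097 elements up to the next multiple of MATRIX_ROW_PADDING = 256
    int64_t padded_row_size = ne00 + MATRIX_ROW_PADDING - 1;  // 4352
    padded_row_size -= padded_row_size % MATRIX_ROW_PADDING;  // 4352 = 17 * 256

The elements past the real row end are zeroed, by the `i < ndata ? x[i] : 0.0f` branch in quantize_q8_1 for the activations and by the cudaMemset above for the weights, so the dot-product kernels can safely read whole padded blocks.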
struct ggml_metal_context;
-struct ggml_metal_context * ggml_metal_init(void);
+// number of command buffers to use
+struct ggml_metal_context * ggml_metal_init(int n_cb);
void ggml_metal_free(struct ggml_metal_context * ctx);
+// set the number of command buffers to use
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
+
// creates a mapping between a host memory buffer and a device memory buffer
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
// - the mapping is used during computation to determine the arguments of the compute kernels
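A minimal sketch of how a caller might drive the updated Metal API (the variable names and the calls to ggml_metal_graph_compute()/ggml_metal_add_buffer() outside this hunk are assumptions based on the comments above):

    // start with a single command buffer, then adjust later without re-initializing
    struct ggml_metal_context * ctx_metal = ggml_metal_init(1);

    // ... map all buffers used in the graph with ggml_metal_add_buffer() ...

    ggml_metal_set_n_cb(ctx_metal, n_cb);        // e.g. raise the command-buffer count before decoding
    ggml_metal_graph_compute(ctx_metal, &gf);    // gf is a previously built ggml_cgraph (assumed)

    ggml_metal_free(ctx_metal);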
};
struct ggml_metal_context {
+ int n_cb;
+
float * logits;
id<MTLDevice> device;
@implementation GGMLMetalClass
@end
-struct ggml_metal_context * ggml_metal_init(void) {
+struct ggml_metal_context * ggml_metal_init(int n_cb) {
fprintf(stderr, "%s: allocating\n", __func__);
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ctx->n_cb = n_cb;
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
free(ctx);
}
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
+ ctx->n_cb = n_cb;
+}
+
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
- const int n_cb = gf->n_threads;
+ const int n_cb = ctx->n_cb;
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
//}
switch (dst->op) {
+ case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
-#if K_QUANTS_PER_ITERATION == 1
+\n#if K_QUANTS_PER_ITERATION == 1\n
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
-#else
+
+\n#else\n
+
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
-#endif
+
+\n#endif\n
+
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const float d = vload_half(0, &x[i].d);
-#if K_QUANTS_PER_ITERATION == 1
+\n#if K_QUANTS_PER_ITERATION == 1\n
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp[16 * ix + tid] += sum;
-#else
+\n#else\n
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp[16 * ix + tid] += sum;
-#endif
+\n#endif\n
}
#include "ggml-opencl.h"
#endif
#elif defined(GGML_USE_OPENBLAS)
+#if defined(GGML_BLAS_USE_MKL)
+#include <mkl.h>
+#else
#include <cblas.h>
+#endif
#elif defined(GGML_USE_CUBLAS)
#include "ggml-cuda.h"
#elif defined(GGML_USE_CLBLAST)
/*.src0 =*/ NULL,
/*.src1 =*/ NULL,
/*.opt =*/ { NULL },
- /*.n_tasks =*/ 0,
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
/*.name =*/ { 0 },
/*.extra =*/ NULL,
- /*.pad =*/ { 0 },
+ /*.padding =*/ { 0 },
};
// TODO: this should not be needed as long as we don't rely on aligned SIMD loads
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
- assert(ne00 % 32 == 0);
-
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
struct ggml_cgraph result = {
/*.n_nodes =*/ 0,
/*.n_leafs =*/ 0,
- /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
- /*.work_size =*/ 0,
- /*.work =*/ NULL,
/*.nodes =*/ { NULL },
/*.grads =*/ { NULL },
/*.leafs =*/ { NULL },
#endif
struct ggml_compute_state_shared {
- struct ggml_cgraph * cgraph;
+ const struct ggml_cgraph * cgraph;
+ const struct ggml_cplan * cplan;
int64_t perf_node_start_cycles;
int64_t perf_node_start_time_us;
- int n_threads;
+ const int n_threads;
// synchronization primitives
atomic_int n_active; // num active threads
static thread_ret_t ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
- struct ggml_cgraph * cgraph = state->shared->cgraph;
- const int n_threads = state->shared->n_threads;
+ const struct ggml_cgraph * cgraph = state->shared->cgraph;
+ const struct ggml_cplan * cplan = state->shared->cplan;
+
+ const int * n_tasks_arr = cplan->n_tasks;
+ const int n_threads = state->shared->n_threads;
+
set_numa_thread_affinity(state->ith, n_threads);
int node_n = -1;
/*.type =*/ GGML_TASK_FINALIZE,
/*.ith =*/ 0,
/*.nth =*/ 0,
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+ /*.wsize =*/ cplan->work_size,
+ /*.wdata =*/ cplan->work_data,
};
if (node_n != -1) {
/* FINALIZE */
struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
if (GGML_OP_HAS_FINALIZE[node->op]) {
- params.nth = node->n_tasks;
+ params.nth = n_tasks_arr[node_n];
ggml_compute_forward(¶ms, node);
ggml_graph_compute_perf_stats_node(node, state->shared);
}
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
struct ggml_tensor * node = cgraph->nodes[node_n];
+ const int n_tasks = n_tasks_arr[node_n];
state->shared->perf_node_start_cycles = ggml_perf_cycles();
state->shared->perf_node_start_time_us = ggml_perf_time_us();
- params.nth = node->n_tasks;
+ params.nth = n_tasks;
/* INIT */
if (GGML_OP_HAS_INIT[node->op]) {
ggml_compute_forward(¶ms, node);
}
- if (node->n_tasks == 1) {
+ if (n_tasks == 1) {
// TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
// they do something more efficient than spinning (?)
params.type = GGML_TASK_COMPUTE;
// wait for other threads to finish
const int last = node_n;
do {
- sched_yield();
+ //sched_yield();
node_n = atomic_load(&state->shared->node_n);
} while (node_n == last);
}
/* COMPUTE */
struct ggml_tensor * node = cgraph->nodes[node_n];
+ const int n_tasks = n_tasks_arr[node_n];
struct ggml_compute_params params = {
/*.type =*/ GGML_TASK_COMPUTE,
/*.ith =*/ state->ith,
- /*.nth =*/ node->n_tasks,
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+ /*.nth =*/ n_tasks,
+ /*.wsize =*/ cplan->work_size,
+ /*.wdata =*/ cplan->work_data,
};
- if (state->ith < node->n_tasks) {
+ if (state->ith < n_tasks) {
ggml_compute_forward(¶ms, node);
}
}
return 0;
}
-void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
- const int n_threads = cgraph->n_threads;
+struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
+ if (n_threads <= 0) {
+ n_threads = GGML_DEFAULT_N_THREADS;
+ }
- struct ggml_compute_state_shared state_shared = {
- /*.cgraph =*/ cgraph,
- /*.perf_node_start_cycles =*/ 0,
- /*.perf_node_start_time_us =*/ 0,
- /*.n_threads =*/ n_threads,
- /*.n_active =*/ n_threads,
- /*.node_n =*/ -1,
- };
- struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+ size_t work_size = 0;
- // initialize tasks + work buffer
- {
- size_t work_size = 0;
+ struct ggml_cplan cplan;
+ memset(&cplan, 0, sizeof(struct ggml_cplan));
- // thread scheduling for the different operations
- for (int i = 0; i < cgraph->n_nodes; i++) {
- struct ggml_tensor * node = cgraph->nodes[i];
+ // thread scheduling for the different operations + work buffer size estimation
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ int n_tasks = 1;
- switch (node->op) {
- case GGML_OP_CPY:
- case GGML_OP_DUP:
- {
- node->n_tasks = n_threads;
+ struct ggml_tensor * node = cgraph->nodes[i];
- size_t cur = 0;
- if (ggml_is_quantized(node->type)) {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
- }
+ switch (node->op) {
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ {
+ n_tasks = n_threads;
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_ADD:
- case GGML_OP_ADD1:
- {
- node->n_tasks = n_threads;
+ size_t cur = 0;
+ if (ggml_is_quantized(node->type)) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
+ }
- size_t cur = 0;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1:
+ {
+ n_tasks = n_threads;
- if (ggml_is_quantized(node->src0->type)) {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
- }
+ size_t cur = 0;
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_ACC:
- {
- node->n_tasks = n_threads;
+ if (ggml_is_quantized(node->src0->type)) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
+ }
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_ACC:
+ {
+ n_tasks = n_threads;
- size_t cur = 0;
+ size_t cur = 0;
- if (ggml_is_quantized(node->src0->type)) {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
- }
+ if (ggml_is_quantized(node->src0->type)) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_SUB:
- case GGML_OP_DIV:
- case GGML_OP_SQR:
- case GGML_OP_SQRT:
- case GGML_OP_LOG:
- case GGML_OP_SUM:
- case GGML_OP_SUM_ROWS:
- case GGML_OP_MEAN:
- case GGML_OP_ARGMAX:
- case GGML_OP_REPEAT:
- case GGML_OP_REPEAT_BACK:
- case GGML_OP_ABS:
- case GGML_OP_SGN:
- case GGML_OP_NEG:
- case GGML_OP_STEP:
- case GGML_OP_TANH:
- case GGML_OP_ELU:
- case GGML_OP_RELU:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_MUL:
- case GGML_OP_GELU:
- case GGML_OP_GELU_QUICK:
- case GGML_OP_SILU:
- case GGML_OP_SILU_BACK:
- case GGML_OP_NORM:
- case GGML_OP_RMS_NORM:
- case GGML_OP_RMS_NORM_BACK:
- {
- node->n_tasks = n_threads;
- } break;
- case GGML_OP_MUL_MAT:
- case GGML_OP_OUT_PROD:
- {
- node->n_tasks = n_threads;
-
- // TODO: use different scheduling for different matrix sizes
- //const int nr0 = ggml_nrows(node->src0);
- //const int nr1 = ggml_nrows(node->src1);
-
- //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
- //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
-
- size_t cur = 0;
- const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_SUB:
+ case GGML_OP_DIV:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_LOG:
+ case GGML_OP_SUM:
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
+ case GGML_OP_ARGMAX:
+ case GGML_OP_REPEAT:
+ case GGML_OP_REPEAT_BACK:
+ case GGML_OP_ABS:
+ case GGML_OP_SGN:
+ case GGML_OP_NEG:
+ case GGML_OP_STEP:
+ case GGML_OP_TANH:
+ case GGML_OP_ELU:
+ case GGML_OP_RELU:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_MUL:
+ case GGML_OP_GELU:
+ case GGML_OP_GELU_QUICK:
+ case GGML_OP_SILU:
+ case GGML_OP_SILU_BACK:
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_OUT_PROD:
+ {
+ n_tasks = n_threads;
+
+ // TODO: use different scheduling for different matrix sizes
+ //const int nr0 = ggml_nrows(node->src0);
+ //const int nr1 = ggml_nrows(node->src1);
+
+ //n_tasks = MIN(n_threads, MAX(1, nr0/128));
+ //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
+
+ size_t cur = 0;
+ const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
#if defined(GGML_USE_CUBLAS)
- if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
- node->n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
- }
- else
+ if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
+ n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ } else
#elif defined(GGML_USE_CLBLAST)
- if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
- node->n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
- cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
- }
- else
+ if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
+ n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
+ } else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
- if (node->src0->type != GGML_TYPE_F32) {
- // here we need memory just for single 2D matrix from src0
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
- }
- } else
-#endif
- if (node->src1->type != vec_dot_type) {
- cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
- } else {
- cur = 0;
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+ n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ if (node->src0->type != GGML_TYPE_F32) {
+ // here we need memory just for single 2D matrix from src0
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
}
+ } else
+#endif
+ if (node->src1->type != vec_dot_type) {
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
+ } else {
+ cur = 0;
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_SCALE:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_SET:
- case GGML_OP_CONT:
- case GGML_OP_RESHAPE:
- case GGML_OP_VIEW:
- case GGML_OP_PERMUTE:
- case GGML_OP_TRANSPOSE:
- case GGML_OP_GET_ROWS:
- case GGML_OP_GET_ROWS_BACK:
- case GGML_OP_DIAG:
- case GGML_OP_DIAG_MASK_ZERO:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
- {
- node->n_tasks = n_threads;
- } break;
- case GGML_OP_ALIBI:
- {
- node->n_tasks = 1; //TODO
- } break;
- case GGML_OP_CLAMP:
- {
- node->n_tasks = 1; //TODO
- } break;
- case GGML_OP_CONV_1D:
- {
- node->n_tasks = n_threads;
-
- GGML_ASSERT(node->src0->ne[3] == 1);
- GGML_ASSERT(node->src1->ne[2] == 1);
- GGML_ASSERT(node->src1->ne[3] == 1);
-
- size_t cur = 0;
- const int nk = node->src0->ne[0];
-
- if (node->src0->type == GGML_TYPE_F16 &&
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_SET:
+ case GGML_OP_CONT:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_GET_ROWS_BACK:
+ case GGML_OP_DIAG:
+ case GGML_OP_DIAG_MASK_ZERO:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_CONV_1D:
+ {
+ n_tasks = n_threads;
+
+ GGML_ASSERT(node->src0->ne[3] == 1);
+ GGML_ASSERT(node->src1->ne[2] == 1);
+ GGML_ASSERT(node->src1->ne[3] == 1);
+
+ size_t cur = 0;
+ const int nk = node->src0->ne[0];
+
+ if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(ggml_fp16_t)*(
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
- );
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*(
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
- );
- } else {
- GGML_ASSERT(false);
- }
+ cur = sizeof(ggml_fp16_t)*(
+ nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+ ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ );
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*(
+ nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
+ ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ );
+ } else {
+ GGML_ASSERT(false);
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CONV_2D:
- {
- node->n_tasks = n_threads;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_CONV_2D:
+ {
+ n_tasks = n_threads;
- GGML_ASSERT(node->src1->ne[3] == 1);
+ GGML_ASSERT(node->src1->ne[3] == 1);
- const int64_t ne00 = node->src0->ne[0]; // W
- const int64_t ne01 = node->src0->ne[1]; // H
- const int64_t ne02 = node->src0->ne[2]; // C
- const int64_t ne03 = node->src0->ne[3]; // N
+ const int64_t ne00 = node->src0->ne[0]; // W
+ const int64_t ne01 = node->src0->ne[1]; // H
+ const int64_t ne02 = node->src0->ne[2]; // C
+ const int64_t ne03 = node->src0->ne[3]; // N
- const int64_t ne10 = node->src1->ne[0]; // W
- const int64_t ne11 = node->src1->ne[1]; // H
- const int64_t ne12 = node->src1->ne[2]; // C
+ const int64_t ne10 = node->src1->ne[0]; // W
+ const int64_t ne11 = node->src1->ne[1]; // H
+ const int64_t ne12 = node->src1->ne[2]; // C
- const int64_t nk = ne00*ne01;
+ const int64_t nk = ne00*ne01;
- UNUSED(ne02);
- UNUSED(ne03);
- UNUSED(nk);
+ UNUSED(ne02);
+ UNUSED(ne03);
+ UNUSED(nk);
- size_t cur = 0;
+ size_t cur = 0;
- if (node->src0->type == GGML_TYPE_F16 &&
+ if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)* (ne10*ne11*ne12);
- } else {
- GGML_ASSERT(false);
- }
+ cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)* (ne10*ne11*ne12);
+ } else {
+ GGML_ASSERT(false);
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_FLASH_ATTN:
- {
- node->n_tasks = n_threads;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_ATTN:
+ {
+ n_tasks = n_threads;
- size_t cur = 0;
+ size_t cur = 0;
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
- }
+ if (node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+ }
- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
- }
+ if (node->src1->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_FLASH_FF:
- {
- node->n_tasks = n_threads;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_FF:
+ {
+ n_tasks = n_threads;
- size_t cur = 0;
+ size_t cur = 0;
- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
- }
+ if (node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+ }
- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
- }
+ if (node->src1->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_FLASH_ATTN_BACK:
- {
- node->n_tasks = n_threads;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_ATTN_BACK:
+ {
+ n_tasks = n_threads;
- size_t cur = 0;
+ size_t cur = 0;
- const int64_t D = node->src0->ne[0];
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
- }
+ const int64_t D = node->src0->ne[0];
+ const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+ if (node->src1->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ }
- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
- }
+ if (node->src1->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ }
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_WIN_PART:
- case GGML_OP_WIN_UNPART:
- case GGML_OP_MAP_UNARY:
- case GGML_OP_MAP_BINARY:
- case GGML_OP_MAP_CUSTOM1:
- case GGML_OP_MAP_CUSTOM2:
- case GGML_OP_MAP_CUSTOM3:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS:
- {
- node->n_tasks = n_threads;
-
- size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- {
- node->n_tasks = n_threads;
-
- size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_NONE:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_COUNT:
- {
- GGML_ASSERT(false);
- } break;
- }
- }
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_WIN_PART:
+ case GGML_OP_WIN_UNPART:
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ case GGML_OP_MAP_CUSTOM1:
+ case GGML_OP_MAP_CUSTOM2:
+ case GGML_OP_MAP_CUSTOM3:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ n_tasks = n_threads;
+
+ size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);
+
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ {
+ n_tasks = n_threads;
+
+ size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;
- if (cgraph->work != NULL && work_size > cgraph->work_size) {
- GGML_ASSERT(false); // TODO: better handling
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_NONE:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
}
- if (work_size > 0 && cgraph->work == NULL) {
- cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+ cplan.n_tasks[i] = n_tasks;
+ }
+
+ if (work_size > 0) {
+ work_size += CACHE_LINE_SIZE*(n_threads - 1);
+ }
- GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
- cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+ cplan.n_threads = n_threads;
+ cplan.work_size = work_size;
+ cplan.work_data = NULL;
+
+ return cplan;
+}
+
+void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+ {
+ GGML_ASSERT(cplan);
+ GGML_ASSERT(cplan->n_threads > 0);
+
+ if (cplan->work_size > 0) {
+ GGML_ASSERT(cplan->work_data);
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ if (cgraph->nodes[i]->op != GGML_OP_NONE) {
+ GGML_ASSERT(cplan->n_tasks[i] > 0);
+ }
}
}
+ const int n_threads = cplan->n_threads;
+
+ struct ggml_compute_state_shared state_shared = {
+ /*.cgraph =*/ cgraph,
+ /*.cplan =*/ cplan,
+ /*.perf_node_start_cycles =*/ 0,
+ /*.perf_node_start_time_us =*/ 0,
+ /*.n_threads =*/ n_threads,
+ /*.n_active =*/ n_threads,
+ /*.node_n =*/ -1,
+ };
+ struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
// create thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; ++j) {
}
}
+void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+
+ struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+ GGML_ASSERT(buf);
+
+ cplan.work_data = buf->data;
+
+ ggml_graph_compute(cgraph, &cplan);
+}
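Since ggml_graph_compute_with_ctx() re-plans and allocates a fresh GGML_TYPE_I8 work tensor from the context on every call, a caller that evaluates the same graph repeatedly can keep one plan and one buffer around instead. A hedged sketch, assuming the graph topology does not change between iterations and that <stdlib.h> is available:

    struct ggml_cplan plan = ggml_graph_plan(&gf, n_threads);
    if (plan.work_size > 0) {
        plan.work_data = (uint8_t *) malloc(plan.work_size);
    }

    for (int i = 0; i < n_iters; ++i) {
        // ... update input tensor data in place ...
        ggml_graph_compute(&gf, &plan);
    }

    free(plan.work_data);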
+
struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
for (int i = 0; i < cgraph->n_leafs; i++) {
struct ggml_tensor * leaf = cgraph->leafs[i];
const int64_t * ne = tensor->ne;
const size_t * nb = tensor->nb;
- fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
arg,
ggml_type_name(tensor->type),
ggml_op_name (tensor->op),
tensor->n_dims,
ne[0], ne[1], ne[2], ne[3],
nb[0], nb[1], nb[2], nb[3],
- tensor->n_tasks,
tensor->data,
tensor->name);
}
struct ggml_cgraph * gb) {
GGML_ASSERT(ggml_is_scalar(f));
- gf->n_threads = params.n_threads;
- gb->n_threads = params.n_threads;
-
// these will store the parameters we want to optimize
struct ggml_tensor * ps[GGML_MAX_PARAMS];
// compute the function value
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
opt->adam.fx_best = opt->adam.fx_prev;
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
const float fx = ggml_get_f32_1d(f, 0);
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
ggml_opt_get_grad(np, ps, g);
}
}
- gf->n_threads = params.n_threads;
- gb->n_threads = params.n_threads;
-
const int m = params.lbfgs.m;
// these will store the parameters we want to optimize
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
ggml_opt_get_grad(np, ps, g);
return 1;
}
+ const int n_threads = 1;
+
int M = atoi(argv[1]);
int N = atoi(argv[2]);
int K = atoi(argv[3]);
dst2 = ggml_mul_mat(ctx0, s0_f32, s1_f32);
struct ggml_cgraph gf = ggml_build_forward(dst2);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
}
{
dst3 = ggml_mul_mat(ctx0, s0_f16, s1_f32);
struct ggml_cgraph gf = ggml_build_forward(dst3);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
}
bool ok_blas = true;
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+
#define MAX_NARGS 3
#undef MIN
int irand(int n) {
if (n == 0) return 0;
- else return rand()%n;
+ return rand()%n;
}
void get_random_dims(int64_t * dims, int ndims) {
float get_element(const struct ggml_tensor * t, int idx) {
if (t->type == GGML_TYPE_F32) {
return ((float *)t->data)[idx];
- } else if (t->type == GGML_TYPE_I32) {
+ }
+
+ if (t->type == GGML_TYPE_I32) {
return ((int32_t *)t->data)[idx];
- } else {
- assert(false);
- return INFINITY;
}
+
+ assert(false);
+ return INFINITY;
}
void set_element(struct ggml_tensor * t, int idx, float value) {
}
struct ggml_cgraph gf = ggml_build_forward (f);
- gf.n_threads = n_threads;
-
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
- gb.n_threads = n_threads;
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
const float xm = x0 - eps;
const float xp = x0 + eps;
set_element(x[i], k, xp);
- ggml_graph_compute(ctx0, &gf);
+
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f0 = ggml_get_f32_1d(f, 0);
set_element(x[i], k, xm);
- ggml_graph_compute(ctx0, &gf);
- const float f1 = ggml_get_f32_1d(f, 0);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
+ const float f1 = ggml_get_f32_1d(f, 0);
const float g0 = (f0 - f1)/(2.0f*eps);
set_element(x[i], k, x0);
// compute gradient using backward graph
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
const float g1 = get_element(x[i]->grad, k);
const float error_abs = fabsf(g0 - g1);
- const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabs(g0) : 0;
+ const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
if (error_abs > max_error_abs || error_rel > max_error_rel) {
printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
float eps,
float max_error_abs,
float max_error_rel) {
+ const int n_threads = 1;
struct ggml_cgraph gf = ggml_build_forward (f);
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
const float x0 = get_element(x[i], k);
set_element(x[i], k, x0 + eps);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f0 = ggml_get_f32_1d(f, 0);
set_element(x[i], k, x0 - eps);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f1 = ggml_get_f32_1d(f, 0);
// compute gradient using backward graph
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
const float g1 = get_element(x[i]->grad, k);
if (argc > 1) {
niter = atoi(argv[1]);
}
+
+ int n_threads = 1;
+
for (int iter = 0; iter < niter; ++iter) {
printf("test-mul-mat0: iter:%d/%d\n", iter, niter);
struct ggml_context * ctx0 = ggml_init(params);
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
} else {
struct ggml_cgraph gf = ggml_build_forward(m);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
}
check_mat_mul(m, x[1], x[0]);
check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
} else {
struct ggml_cgraph gf = ggml_build_forward(m);
- ggml_graph_compute(ctx0, &gf);
+ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
}
check_mat_mul(m, x[1], x[0]);
#define MAX_NARGS 2
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
//
// logging
#define GGML_PRINT(...) printf(__VA_ARGS__)
-float frand() {
+float frand(void) {
return (float)rand()/(float)RAND_MAX;
}
((float *)t->data)[idx] = value;
}
-int main(int argc, const char ** argv) {
+int main(void) {
struct ggml_init_params params = {
.mem_size = 1024*1024*1024,
.mem_buffer = NULL,
struct ggml_tensor * d = ggml_sub(ctx, c, ab);
struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
-
struct ggml_cgraph ge = ggml_build_forward(e);
- ggml_graph_reset (&ge);
- ggml_graph_compute(ctx, &ge);
+ ggml_graph_reset(&ge);
+
+ ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe);
ggml_opt(ctx, opt_params, e);
- ggml_graph_reset (&ge);
- ggml_graph_compute(ctx, &ge);
+ ggml_graph_reset(&ge);
+
+ ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
+
const float fe_opt = ggml_get_f32_1d(e, 0);
printf("%s: original e = %.4f\n", __func__, fe);
printf("%s: optimized e = %.4f\n", __func__, fe_opt);
#include <stdlib.h>
int main(int argc, const char ** argv) {
+ const int n_threads = 2;
+
struct ggml_init_params params = {
.mem_size = 128*1024*1024,
.mem_buffer = NULL,
ggml_graph_reset(&gf);
ggml_set_f32(f->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("f = %f\n", ggml_get_f32_1d(f, 0));
printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
ggml_graph_reset(&gf);
ggml_set_f32(f->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("f = %f\n", ggml_get_f32_1d(f, 0));
printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
ggml_set_f32(g1->grad, 1.0f);
ggml_set_f32(g2->grad, 1.0f);
- ggml_graph_compute(ctx0, &gbb);
+ ggml_graph_compute_with_ctx(ctx0, &gbb, n_threads);
printf("H * [1, 1] = [ %f %f ]\n", ggml_get_f32_1d(x1->grad, 0), ggml_get_f32_1d(x2->grad, 0));
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f\n", ggml_get_f32_1d(x1->grad, 0));
ggml_set_f32(g2->grad, 1.0f);
ggml_set_f32(g3->grad, 1.0f);
- ggml_graph_compute(ctx0, &gbb);
+ ggml_graph_compute_with_ctx(ctx0, &gbb, n_threads);
printf("H * [1, 1, 1] = [ %f %f %f ]\n",
ggml_get_f32_1d(x1->grad, 0),
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f %f %f\n",
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f %f %f\n",
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f %f %f\n",
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f %f %f\n",
ggml_graph_reset(&gf);
ggml_set_f32(y->grad, 1.0f);
- ggml_graph_compute(ctx0, &gb);
+ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
printf("y = %f\n", ggml_get_f32_1d(y, 0));
printf("df/dx1 = %f %f %f\n",
});\r
\r
pub fn main() !void {\r
+ const n_threads = 2;\r
+\r
const params = .{\r
.mem_size = 128*1024*1024,\r
.mem_buffer = null,\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(f.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});\r
std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(f.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});\r
std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
_ = c.ggml_set_f32(g1.*.grad, 1.0);\r
_ = c.ggml_set_f32(g2.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gbb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gbb), n_threads);\r
\r
std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0)});\r
\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});\r
_ = c.ggml_set_f32(g2.*.grad, 1.0);\r
_ = c.ggml_set_f32(g3.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gbb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gbb), n_threads);\r
\r
std.debug.print("H * [1, 1, 1] = [ {d:.6} {d:.6} {d:.6}]\n",\r
.{\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r
c.ggml_graph_reset(@constCast(&gf));\r
_ = c.ggml_set_f32(y.*.grad, 1.0);\r
\r
- c.ggml_graph_compute(ctx0, @constCast(&gb));\r
+ c.ggml_graph_compute_with_ctx(ctx0, @constCast(&gb), n_threads);\r
\r
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});\r
std.debug.print("df/dx1 = {d:.6} {d:.6} {d:.6}\n",\r