int32_t n_rot = 20; // rotary_pct (0.25) * (n_embd / n_head)
int32_t par_res = 1; // 1 = true, 0 = false
int32_t ftype = GGML_FTYPE_MOSTLY_F16;
+ float eps = 1e-5;
};
const std::string INSTRUCTION_KEY = "### Instruction:";
// feed-forward network
ggml_tensor * gpt_neox_ff(
- const dollyv2_layer &layer,
- ggml_context * ctx0,
- ggml_tensor * inp) {
- ggml_tensor * cur = ggml_norm(ctx0, inp);
+ const dollyv2_layer & layer,
+ ggml_context * ctx0,
+ ggml_tensor * inp,
+ float eps) {
+ ggml_tensor * cur = ggml_norm(ctx0, inp, eps);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
// self-attention
{
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
if (hparams.par_res == 0) {
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
- cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+ cur = gpt_neox_ff(model.layers[il], ctx0, inpFF, hparams.eps);
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
// this is independent of the self-attention result, so it could be done in parallel to the self-attention
// note here we pass inpL instead of cur
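// (parallel residual: the block output is inp + attn_out + ff_out, with both terms computed from the same layer input)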
- cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+ cur = gpt_neox_ff(model.layers[il], ctx0, inpL, hparams.eps);
// layer input + FF
cur = ggml_add(ctx0, cur, inpFF);
// norm
{
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, hparams.eps);
// inpL = ln_f_g*inpL + ln_f_b
inpL = ggml_add(ctx0,
int32_t n_head = 12;
int32_t n_layer = 12;
int32_t ftype = 1;
+ float eps = 1e-5;
};
struct gpt2_layer {
// norm
{
// [ 768, N]
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
{
// norm
{
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
// norm
{
// [ 768, N]
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, hparams.eps);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
int32_t n_layer = 28;
int32_t n_rot = 64;
int32_t ftype = 1;
+ float eps = 1e-5;
};
struct gptj_layer {
// norm
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_1_g*cur + ln_1_b
cur = ggml_add(ctx0,
// norm
{
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, hparams.eps);
// inpL = ln_f_g*inpL + ln_f_b
inpL = ggml_add(ctx0,
int32_t n_rot = 32; // rotary_pct * (n_embd / n_head)
int32_t par_res = 1; // 1 = true, 0 = false
int32_t ftype = 1;
+ float eps = 1e-5;
};
struct gpt_neox_layer {
// feed-forward network
ggml_tensor * gpt_neox_ff(
- const gpt_neox_layer &layer,
- ggml_context * ctx0,
- ggml_tensor * inp) {
- ggml_tensor * cur = ggml_norm(ctx0, inp);
+ const gpt_neox_layer & layer,
+ ggml_context * ctx0,
+ ggml_tensor * inp,
+ float eps) {
+ ggml_tensor * cur = ggml_norm(ctx0, inp, eps);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
// self-attention
{
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
if (hparams.par_res == 0) {
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
- cur = gpt_neox_ff(model.layers[il], ctx0, inpFF);
+ cur = gpt_neox_ff(model.layers[il], ctx0, inpFF, hparams.eps);
// input for next layer
inpL = ggml_add(ctx0, cur, inpFF);
// this is independent of the self-attention result, so it could be done in parallel to the self-attention
// note here we pass inpL instead of cur
- cur = gpt_neox_ff(model.layers[il], ctx0, inpL);
+ cur = gpt_neox_ff(model.layers[il], ctx0, inpL, hparams.eps);
// layer input + FF
cur = ggml_add(ctx0, cur, inpFF);
// norm
{
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, hparams.eps);
// inpL = ln_f_g*inpL + ln_f_b
inpL = ggml_add(ctx0,
const int n_head = hparams.n_heads;
const int n_vocab = hparams.n_vocab;
const int n_ctx = hparams.n_ctx;
+ const float eps = 1e-5;
static size_t buf_size = 256u * 1024 * 1024;
static void * buf = malloc(buf_size);
// a = self.ln_1(x)
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, eps);
cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
}
// m = self.ln_2(x)
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, eps);
cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
}
// norm
{
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, eps);
// inpL = ln_f_g*inpL
inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
}
const int n_head = hparams.n_heads;
const int n_vocab = hparams.n_vocab;
const int n_ctx = hparams.max_seq_len;
+ const float eps = 1e-5;
static size_t buf_size = 256u * 1024 * 1024;
static void * buf = malloc(buf_size);
// a = self.ln_1(x)
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, eps);
cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_1_weight, cur), cur);
}
// m = self.ln_2(x)
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, eps);
cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].norm_2_weight, cur), cur);
}
// norm
{
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, eps);
// inpL = ln_f_g*inpL
inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
}
fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")
+# Default params are set to sam_vit_b checkpoint
+n_enc_state = 768
+n_enc_layers = 12
+n_enc_heads = 12
+n_enc_out_chans = 256
+n_pt_embd = 4
+
model = torch.load(fname_model, map_location="cpu")
+for k, v in model.items():
+ print(k, v.shape)
+ if k == "image_encoder.blocks.0.norm1.weight":
+ n_enc_state = v.shape[0]
-# TODO: determine based on model data
-# TODO: add decoder / prompt encoder if needed
-hparams = {
- "n_enc_state": 768,
- "n_enc_layers": 12,
- "n_enc_heads": 12,
- "n_enc_out_chans": 256,
+if n_enc_state == 1024: # sam_vit_l
+ n_enc_layers = 24
+ n_enc_heads = 16
+elif n_enc_state == 1280: # sam_vit_h
+ n_enc_layers = 32
+ n_enc_heads = 16
- "n_pt_embd": 4,
+hparams = {
+ "n_enc_state": n_enc_state,
+ "n_enc_layers": n_enc_layers,
+ "n_enc_heads": n_enc_heads,
+ "n_enc_out_chans": n_enc_out_chans,
+ "n_pt_embd": n_pt_embd,
}
print(hparams)
#data = tf.train.load_variable(dir_model, name).squeeze()
#data = v.numpy().squeeze()
data = v.numpy()
- n_dims = len(data.shape);
+ n_dims = len(data.shape)
# for efficiency - transpose some matrices
# "model/h.*/attn/c_attn/w"
# keep it in F32 since the data is small
if name == "image_encoder.patch_embed.proj.bias":
data = data.reshape(1, data.shape[0], 1, 1)
- n_dims = len(data.shape);
+ n_dims = len(data.shape)
dshape = data.shape
print(" New shape: ", dshape)
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", dshape[n_dims - 1 - i]))
- fout.write(str);
+ fout.write(str)
# data
data.tofile(fout)
#include <string>
#include <vector>
+// void print_t_f32(const char* title, struct ggml_tensor * t, int n = 10) {
+// printf("%s\n", title);
+// float * data = (float *)t->data;
+// printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
+// printf("First & Last %d elements:\n", n);
+// for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+// printf("%.5f ", data[i]);
+// if (i != 0 && i % t->ne[0] == 0) {
+// printf("\n");
+// }
+// }
+// printf("\n");
+// for (int i = 0; i < std::min((int) (t->ne[0]*t->ne[1]), n); i++) {
+// printf("%.5f ", data[ggml_nelements(t) - n + i]);
+// if ((ggml_nelements(t) - n + i) % t->ne[0] == 0) {
+// printf("\n");
+// }
+// }
+// printf("\n");
+// double sum = 0.0;
+// for (int i = 0; i < ggml_nelements(t); i++) {
+// sum += data[i];
+// }
+// printf("sum: %f\n\n", sum);
+// }
+
// default hparams (ViT-B SAM)
struct sam_hparams {
int32_t n_enc_state = 768;
float iou_threshold = 0.88f;
float stability_score_threshold = 0.95f;
float stability_score_offset = 1.0f;
+ float eps = 1e-6f;
+ float eps_decoder_transformer = 1e-5f;
int32_t n_enc_head_dim() const { return n_enc_state / n_enc_head; }
int32_t n_img_size() const { return 1024; }
struct ggml_tensor * layer,
int n_channels,
struct ggml_tensor * w,
- struct ggml_tensor * b) {
+ struct ggml_tensor * b,
+ float eps) {
// LayerNorm2d
// normalize along channel dimension
// TODO: better implementation
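// ggml_norm reduces over dim 0 (rows), so the permutes move the channel dimension
// into dim 0, normalize over it, then restore the original [W, H, C, N] layout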
layer = ggml_permute(ctx0,
- ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3))),
+ ggml_norm(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, layer, 1, 2, 0, 3)), eps),
2, 0, 1, 3);
layer = ggml_add(ctx0,
// norm
// ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L168
{
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
{
// norm
{
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
// cur = mlp_ln_w*cur + mlp_ln_b
cur = ggml_add(ctx0,
cur = ggml_conv_2d_sk_p0(ctx0, enc.neck_conv_0, cur);
- cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b);
+ cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_0_w, enc.neck_norm_0_b, hparams.eps);
cur = ggml_conv_2d_s1_ph(ctx0, enc.neck_conv_1, cur);
- cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b);
+ cur = sam_layer_norm_2d(ctx0, cur, n_enc_out_chans, enc.neck_norm_1_w, enc.neck_norm_1_b, hparams.eps);
// TODO: avoid copy
cur = ggml_cpy(ctx0, cur, state.embd_img);
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf = {};
- // auto print_t_f32 = [&](const char* title, struct ggml_tensor * t) {
- // printf("%s\n", title);
- // float * data = (float *)t->data;
- // printf("dims: %jd %jd %jd %jd f32\n", t->ne[0], t->ne[1], t->ne[2], t->ne[3]);
- // printf("First 10 elements:\n");
- // for (int i = 0; i < std::min((int) t->ne[0], 10); i++) {
- // printf("%f ", data[i]);
- // }
- // printf("\n");
- // double sum = 0.0;
- // for (int i = 0; i < ggml_nelements(t); i++) {
- // sum += data[i];
- // }
- // printf("sum: %f\n\n", sum);
- // };
// print_t_f32("embd_img", state.embd_img);
// print_t_f32("embd_prompt_dense", state.embd_prompt_dense);
// print_t_f32("embd_prompt_sparse", state.embd_prompt_sparse);
queries = ggml_add(ctx0, queries, self_attn);
}
- queries = ggml_norm(ctx0, queries);
+ queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
queries = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, tfm_layer.norm1_w, queries),
struct ggml_tensor * cross_attn_token_to_img = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_token_to_img, q_1, k_1, keys, ctx0, model);
queries = ggml_add(ctx0, queries, cross_attn_token_to_img);
- queries = ggml_norm(ctx0, queries);
+ queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
queries = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, tfm_layer.norm2_w, queries),
mlp_out);
queries = ggml_add(ctx0, queries, mlp_out);
- queries = ggml_norm(ctx0, queries);
+ queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
queries = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, tfm_layer.norm3_w, queries),
struct ggml_tensor * cross_attn_img_to_token = sam_decode_mask_transformer_attn(tfm_layer.cross_attn_img_to_token, k_2, q_2, queries, ctx0, model);
keys = ggml_add(ctx0, keys, cross_attn_img_to_token);
- keys = ggml_norm(ctx0, keys);
+ keys = ggml_norm(ctx0, keys, hparams.eps_decoder_transformer);
keys = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, tfm_layer.norm4_w, keys),
struct ggml_tensor * final_attn_token_to_img = sam_decode_mask_transformer_attn(dec.transformer_final_attn_token_to_img, q, k, keys, ctx0, model);
queries = ggml_add(ctx0, queries, final_attn_token_to_img);
- queries = ggml_norm(ctx0, queries);
+ queries = ggml_norm(ctx0, queries, hparams.eps_decoder_transformer);
queries = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, dec.transformer_norm_final_w, queries),
{
// ConvTranspose2d
keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_0_w, keys, 2);
- keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_0_b, keys), keys);
- keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b);
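+ // reshape the per-channel bias to [1, 1, C] so ggml_repeat broadcasts it across the spatial dims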
+ keys = ggml_add(ctx0, keys, ggml_repeat(ctx0,
+ ggml_reshape_3d(ctx0, dec.output_upscaling_0_b, 1, 1, dec.output_upscaling_0_b->ne[0]),
+ keys));
+
+ keys = sam_layer_norm_2d(ctx0, keys, n_img_embd, dec.output_upscaling_1_w, dec.output_upscaling_1_b, hparams.eps);
// GELU activation
keys = ggml_gelu(ctx0, keys);
// ConvTranspose2d
keys = ggml_conv_transpose_2d_p0(ctx0, dec.output_upscaling_3_w, keys, 2);
- keys = ggml_add(ctx0, ggml_repeat(ctx0, dec.output_upscaling_3_b, keys), keys);
-
+ keys = ggml_add(ctx0, ggml_repeat(ctx0,
+ ggml_reshape_3d(ctx0, dec.output_upscaling_3_b, 1, 1, dec.output_upscaling_3_b->ne[0]),
+ keys), keys);
// GELU activation
keys = ggml_gelu(ctx0, keys);
upscaled_embedding = ggml_reshape_3d(ctx0, keys, keys->ne[0]*keys->ne[1], keys->ne[2], keys->ne[3]);
int32_t n_head = 16;
int32_t n_layer = 24;
int32_t ftype = 1;
+ float eps = 1e-5;
};
struct starcoder_layer {
// norm
{
// [ 768, N]
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
{
// norm
{
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
// norm
{
// [ 768, N]
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, hparams.eps);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
int32_t n_head = 16;
int32_t n_layer = 24;
int32_t ftype = 1;
+ float eps = 1e-5;
};
struct starcoder_layer {
// norm
{
// [ 768, N]
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_1_g*cur + ln_1_b
// [ 768, N]
{
// norm
{
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
// norm
{
// [ 768, N]
- inpL = ggml_norm(ctx0, inpL);
+ inpL = ggml_norm(ctx0, inpL, hparams.eps);
// inpL = ln_f_g*inpL + ln_f_b
// [ 768, N]
int32_t n_text_layer = 4;
int32_t n_mels = 80;
int32_t ftype = 1;
+ float eps = 1e-5;
};
// audio encoding layer
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
wstate.use_buf(ctx0, 1);
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, cur);
+ cur = ggml_norm(ctx0, cur, hparams.eps);
wstate.use_buf(ctx0, 1);
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, inpL);
+ cur = ggml_norm(ctx0, inpL, hparams.eps);
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
+ cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here
// cur = ln_0_w*cur + ln_0_b
cur = ggml_add(ctx0,
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, inpFF);
+ cur = ggml_norm(ctx0, inpFF, hparams.eps);
wstate.use_buf(ctx0, 1);
{
wstate.use_buf(ctx0, 0);
- cur = ggml_norm(ctx0, cur);
+ cur = ggml_norm(ctx0, cur, hparams.eps);
wstate.use_buf(ctx0, 1);
struct ggml_tensor * b);
// normalize along rows
- // TODO: eps is hardcoded to 1e-5 for now
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a,
+ float eps);
GGML_API struct ggml_tensor * ggml_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a,
+ float eps);
GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
- /*.allocated_tensors = */ = {0},
+ /*.allocated_tensors = */ {0},
#endif
};
/*.parse_seq = */ {0},
/*.has_parse_seq = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
- /*.allocated_tensors = */ = {0},
+ /*.allocated_tensors = */ {0},
#endif
};
struct ggml_tensor * view_src = get_view_source(parent);
struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
- AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src->n_children, view_src->n_views);
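+ // the child/view counters are tracked on the hash_node entry, not on the tensor itself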
+ AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
ggml_allocator_free_tensor(alloc, view_src);
}
static struct ggml_tensor * ggml_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
+ float eps,
bool inplace) {
bool is_node = false;
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- // TODO: maybe store epsilon here?
+ ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_norm_impl(ctx, a, false);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_norm_impl(ctx, a, eps, false);
}
struct ggml_tensor * ggml_norm_inplace(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_norm_impl(ctx, a, true);
+ struct ggml_tensor * a,
+ float eps) {
+ return ggml_norm_impl(ctx, a, eps, true);
}
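// The epsilon now travels in the op's params blob (ggml_set_op_params above) and is
// read back in the CPU kernel via memcpy from dst->op_params.
// Minimal caller-side sketch, assuming hparams.eps holds the model's LayerNorm epsilon:
//
//   struct ggml_tensor * cur = ggml_norm(ctx0, inpL, hparams.eps);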
// ggml_rms_norm
};
struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ ggml_set_op_params_i32(result, 0, stride);
+
result->op = GGML_OP_CONV_TRANSPOSE_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = b;
- result->src[2] = ggml_new_i32(ctx, stride);
return result;
}
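// The stride is likewise carried in op_params (ggml_set_op_params_i32 / ggml_get_op_params_i32)
// rather than an extra ggml_new_i32 tensor in src[2], so
// ggml_compute_forward_conv_transpose_2d no longer takes the opt0 argument.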
GGML_TENSOR_UNARY_OP_LOCALS;
- const float eps = 1e-5f; // TODO: make this a parameter
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
return;
}
- const int32_t stride = ((const int32_t*)(opt0->data))[0];
+ const int32_t stride = ggml_get_op_params_i32(dst, 0);
// total patches in dst
const int np = ne2;
const int ip1 = MIN(ip0 + dp, np);
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
- ggml_fp16_t * const wdata_src = (ggml_fp16_t *) params->wdata + nk;
+ ggml_fp16_t * const wdata_src = wdata + nk;
for (int i2 = ip0; i2 < ip1; i2++) { // Cout
float * dst_data = (float *)((char *) dst->data + i2*nb2);
for (int i00 = 0; i00 < ne00; i00++) {
float v = 0;
ggml_vec_dot_f16(ne03, &v,
- (ggml_fp16_t *) wdata_src + i1n,
- (ggml_fp16_t *) wdata_kernel + i01*ne00*ne03 + i00*ne03);
-
+ wdata_src + i1n,
+ wdata_kernel + i01*ne00*ne03 + i00*ne03);
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
}
}
} break;
case GGML_OP_CONV_TRANSPOSE_2D:
{
- ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+ ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_POOL_1D:
{