vk_pipeline pipeline_div_norepeat[2][2][2];
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
- vk_pipeline pipeline_upscale_f32;
+ vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
vk_pipeline pipeline_sin_f32;
struct vk_op_upscale_push_constants {
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+ uint32_t ne00; uint32_t ne01;
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
float sf0; float sf1; float sf2; float sf3;
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_upscale_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1);
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
}
return nullptr;
case GGML_OP_UPSCALE:
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
- return ctx->device->pipeline_upscale_f32;
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ int mode = ggml_get_op_params_i32(dst, 0);
+ switch (mode) {
+ case GGML_SCALE_MODE_NEAREST:
+ return ctx->device->pipeline_upscale_nearest_f32;
+ case GGML_SCALE_MODE_BILINEAR:
+ return ctx->device->pipeline_upscale_bilinear_f32;
+ case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS:
+ return ctx->device->pipeline_upscale_bilinear_ac_f32;
+ }
}
return nullptr;
case GGML_OP_SCALE:
static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
- const float sf0 = (float)dst->ne[0] / src0->ne[0];
- const float sf1 = (float)dst->ne[1] / src0->ne[1];
- const float sf2 = (float)dst->ne[2] / src0->ne[2];
- const float sf3 = (float)dst->ne[3] / src0->ne[3];
+ float sf0 = (float)dst->ne[0] / src0->ne[0];
+ float sf1 = (float)dst->ne[1] / src0->ne[1];
+ float sf2 = (float)dst->ne[2] / src0->ne[2];
+ float sf3 = (float)dst->ne[3] / src0->ne[3];
+
+ if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1);
+ sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1);
+ }
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
(uint32_t)ggml_nelements(dst), 0, 0,
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1],
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
sf0, sf1, sf2, sf3,
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_UPSCALE:
- return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_ACC:
case GGML_OP_CONCAT:
case GGML_OP_SCALE:
case GGML_OP_PAD:
+ case GGML_OP_ROLL:
case GGML_OP_DIAG_MASK_INF:
- return true;
case GGML_OP_SOFT_MAX:
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT:
layout (push_constant) uniform parameter
{
uint ne; uint a_offset; uint d_offset;
+ uint ne00; uint ne01;
uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13;
float sf0; float sf1; float sf2; float sf3;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag
+#define NEAREST 0
+#define BILINEAR 1
+#define ALIGN_CORNERS (1 << 8)
+
+layout (constant_id = 0) const uint scale_mode = 0;
+
+float fetch_nearest(uint i10, uint i11, uint i12, uint i13) {
+ const uint i00 = uint(i10 / p.sf0);
+ const uint i01 = uint(i11 / p.sf1);
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+
+ return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00];
+}
+
+float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) {
+ const uint i02 = uint(i12 / p.sf2);
+ const uint i03 = uint(i13 / p.sf3);
+ const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02;
+
+ const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00];
+ const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00];
+ const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00];
+ const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00];
+
+ return
+ v00 * (1.0-d.x) * (1.0-d.y) +
+ v01 * d.x * (1.0-d.y) +
+ v10 * (1.0-d.x) * d.y +
+ v11 * d.x * d.y;
+}
+
+float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) {
+ const ivec2 ne0 = ivec2(p.ne00, p.ne01);
+
+ const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5;
+ const vec2 c0f = floor(c);
+ const vec2 d = c - c0f;
+ const ivec2 c0 = max(ivec2(c0f), 0);
+ const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1);
+
+ return fetch_bilinear(c0, c1, d, i12, i13);
+}
+
+float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) {
+ const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1);
+ const vec2 c0f = floor(c);
+ const vec2 d = c - c0f;
+ const ivec2 c0 = ivec2(c0f);
+ const ivec2 c1 = c0 + 1;
+
+ return fetch_bilinear(c0, c1, d, i12, i13);
+}
+
void main() {
const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12;
const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13;
- const uint i00 = uint(i10 / p.sf0);
- const uint i01 = uint(i11 / p.sf1);
- const uint i02 = uint(i12 / p.sf2);
- const uint i03 = uint(i13 / p.sf3);
+ float result;
+ switch (scale_mode) {
+ case NEAREST:
+ result = fetch_nearest(i10, i11, i12, i13);
+ break;
+ case BILINEAR:
+ result = interpolate_bilinear(i10, i11, i12, i13);
+ break;
+ case BILINEAR | ALIGN_CORNERS:
+ result = interpolate_bilinear_align_corners(i10, i11, i12, i13);
+ break;
+ }
- data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
+ data_d[p.d_offset + idx] = D_TYPE(result);
}