uint32_t mode; // 0: default, 1: swapped, 2: split
float alpha; // for swiglu_oai
float limit;
+ uint32_t nb01;
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t ne01;
+ uint32_t ne02;
+ uint32_t nb11;
+ uint32_t nb12;
+ uint32_t nb13;
+ uint32_t ne11;
+ uint32_t ne12;
};
struct vk_op_unary_push_constants {
} else {
device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
}
- vk::DeviceCreateInfo device_create_info;
+ vk::DeviceCreateInfo device_create_info{};
std::vector<const char *> device_extensions;
vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
#endif
device->name = GGML_VK_NAME + std::to_string(idx);
- device_create_info = {
- vk::DeviceCreateFlags(),
- device_queue_create_infos,
- {},
- device_extensions
- };
+ device_create_info
+ .setFlags(vk::DeviceCreateFlags())
+ .setQueueCreateInfos(device_queue_create_infos)
+ .setPEnabledExtensionNames(device_extensions);
device_create_info.setPNext(&device_features2);
device->device = device->physical_device.createDevice(device_create_info);
const float alpha = op_params_f[2];
const float limit = op_params_f[3];
- GGML_ASSERT(ggml_is_contiguous(src0));
-
if (!split) {
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
} else {
(uint32_t)dst->ne[0],
mode,
alpha,
- limit
+ limit,
+ (uint32_t)(src0->nb[1] / src0->nb[0]),
+ (uint32_t)(src0->nb[2] / src0->nb[0]),
+ (uint32_t)(src0->nb[3] / src0->nb[0]),
+ (uint32_t)src0->ne[1],
+ (uint32_t)src0->ne[2],
+ (uint32_t)(dst->nb[1] / dst->nb[0]),
+ (uint32_t)(dst->nb[2] / dst->nb[0]),
+ (uint32_t)(dst->nb[3] / dst->nb[0]),
+ (uint32_t)dst->ne[1],
+ (uint32_t)dst->ne[2]
});
}
case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
- return ggml_is_contiguous(op->src[0]) &&
- (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
(op->src[0]->type == op->type);
default:
const uint row = i / p.ne20;
const uint col = i - row * p.ne20;
+ const uint i3 = row / (p.ne01 * p.ne02);
+ const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01;
+ const uint i1 = row % p.ne01;
+ const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col;
+
+ const uint dst_i3 = row / (p.ne11 * p.ne12);
+ const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11;
+ const uint dst_i1 = row % p.ne11;
+ const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col;
+
if (p.mode == 0) {
// Default
const uint offset = p.ne00 / 2;
- const uint idx = row * p.ne00 + col;
+ const uint idx = src_idx;
- data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+ data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
} else if (p.mode == 1) {
// Swapped
const uint offset = p.ne00 / 2;
- const uint idx = row * p.ne00 + col;
+ const uint idx = src_idx;
- data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+ data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
} else {
// Split
- const uint idx = row * p.ne00 + col;
+ const uint idx = src_idx;
- data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+ data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
}
}