return s;
};
+ bool rte = device->float_controls_rte_fp16;
#define CREATE_BINARY(name, namemod, spec) \
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
- #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
CREATE_BINARY(add, , {0})
#undef CREATE_UNARY
#define CREATE_GLU(name) \
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
+ if (device->float_controls_rte_fp16) { \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+ } else { \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
+ }
CREATE_GLU(geglu)
CREATE_GLU(reglu)
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
- auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
- string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+ for (auto rte : {false, true}) {
+ auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
+ string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ }
}
}
}
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
- string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
- string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+ for (auto rte : {false, true}) {
+ std::string suffix = rte ? "_rte" : "";
+ string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ }
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
std::remove(path.c_str());
}
}
+
+ std::string suffixes[2] = {"_f32", "_f16"};
for (const char *op : {"add", "sub", "mul", "div"}) {
- fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
- fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
- fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
- fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+ fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
+ fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
+ std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
+ std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
+ for (uint32_t t0 = 0; t0 < 2; ++t0) {
+ if (t0 == 0) {
+ data += "{";
+ len += "{";
+ }
+ for (uint32_t t1 = 0; t1 < 2; ++t1) {
+ if (t1 == 0) {
+ data += "{";
+ len += "{";
+ }
+ for (uint32_t t2 = 0; t2 < 2; ++t2) {
+ if (t2 == 0) {
+ data += "{";
+ len += "{";
+ }
+ for (uint32_t rte = 0; rte < 2; ++rte) {
+ if (rte == 0) {
+ data += "{";
+ len += "{";
+ }
+ data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
+ len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
+ data += "_data,";
+ len += "_len,";
+ if (rte == 1) {
+ data += "}, ";
+ len += "}, ";
+ }
+ }
+ if (t2 == 1) {
+ data += "}, ";
+ len += "}, ";
+ }
+ }
+ if (t1 == 1) {
+ data += "}, ";
+ len += "}, ";
+ }
+ }
+ if (t0 == 1) {
+ data += "};\n";
+ len += "};\n";
+ }
+ }
+ fprintf(src, data.c_str());
+ fprintf(src, len.c_str());
}
fclose(hdr);
fclose(src);