]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: fix group_norm (llama/10496)
authorJeff Bolz <redacted>
Tue, 26 Nov 2024 15:45:05 +0000 (09:45 -0600)
committerGeorgi Gerganov <redacted>
Tue, 3 Dec 2024 19:05:37 +0000 (21:05 +0200)
Fix bad calculation of the end of the range. Add a backend test that
covers the bad case (taken from stable diffusion).

Fixes https://github.com/leejet/stable-diffusion.cpp/issues/439.

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/group_norm.comp
tests/test-backend-ops.cpp

index 49527fdf40e947b9733f4112e4bb61326dd582ca..da1cfd24eac2253e21ad926e177d7e0ec1005dd0 100644 (file)
@@ -7157,7 +7157,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         const int32_t max_period = tensor->op_params[1];
         tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
     } else if (tensor->op == GGML_OP_POOL_2D) {
-        enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
+        enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
         const int32_t k0 = tensor->op_params[1];
         const int32_t k1 = tensor->op_params[2];
         const int32_t s0 = tensor->op_params[3];
index 5ad9b28daffaaa1764c3632c2fa1b4fd90f73884..b6a0d56454951ff5f293019bb235835bbd222377 100644 (file)
@@ -19,7 +19,7 @@ void main() {
 
     const uint tid = gl_LocalInvocationID.x;
     const uint start = gl_WorkGroupID.x * group_size + tid;
-    const uint end = start + group_size;
+    const uint end = (gl_WorkGroupID.x + 1) * group_size;
 
     tmp[tid] = 0.0f;
 
index caf64c0e9d00aaa90d1b897f23639c2f01641536..7dfeddce56dbf44ed1a1b9497c741a33f06b42f2 100644 (file)
@@ -3774,7 +3774,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_upscale());
     test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
     test_cases.emplace_back(new test_upscale_ext());
-    test_cases.emplace_back(new test_group_norm());
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_arange());