]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)
authorJohannes Gäßler <redacted>
Fri, 24 Jan 2025 11:38:31 +0000 (12:38 +0100)
committerGitHub <redacted>
Fri, 24 Jan 2025 11:38:31 +0000 (12:38 +0100)
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ggml-cpu.cpp
ggml/src/ggml-cuda/binbcast.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/out-prod.cu
ggml/src/ggml.c
tests/test-backend-ops.cpp

index 0ed92b3ff4049a12fb065fc3aff5d203459d13fc..9e627da8c34d4d404fbe26fce39829059ba00ddd 100644 (file)
@@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                     float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                     float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 
                     ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
                 }
@@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                     float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                     float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 
                     ggml_vec_mad_f32(ne0, d, s0, *s1);
                 }
index 35a1c876c8631ff1bbd3219eb299aba67c16e634..2ccb4b472d6128cb7e46b059379cb0901affc2e4 100644 (file)
@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }
index c7b6be4e2905c1d927f0bc3860a0617b1ab16ce4..ce4b9cfb51bff26bb152a388e1ec6860739ea3af 100644 (file)
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
 
-    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
-    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
-    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
 
     if (tid0 >= ne0) {
         return;
     }
 
     T sum = 0;
-    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+                }
             }
         }
     }
-    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
 
 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
 }
 
 template<class op>
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    GGML_ASSERT(ne2*ne3 <= (1 << 15));
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    GGML_ASSERT(dst->ne[3] == 1);
+    const size_t ts = ggml_type_size(src0->type);
+    const size_t s00 = nb00 / ts;
+    const size_t s01 = nb01 / ts;
+    const size_t s02 = nb02 / ts;
+    const size_t s03 = nb03 / ts;
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float       * dst_d  = (float       *) dst->data;
-            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
         } break;
         default: {
             GGML_ASSERT(false);
index 7fd1fc85346f1cff40065ada241db9f003fdbeae..e602419bc57abbf7c6dc5f849b39d2c89f78623f 100644 (file)
@@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_REPEAT_BACK:
-                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+                return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
index 73e3e2c47f28a8801c38b730748432618f926f94..c9b2b699c6a55c217d3c99d67fd9fd4f1db07ef6 100644 (file)
@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
+    const int64_t lda = nb01 / sizeof(float);
+    const int64_t ldc = nb1  / sizeof(float);
+
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             CUBLAS_CHECK(
                 cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                         ne0, ne1, ne01,
-                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                 src1_d +  i3      *s13 +  i2      *s12, ldb,
-                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ne0));
+                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ldc));
         }
     }
 }
index b1d0d4913f8e1ed025cf32c1cb97857d448f23ed..92c4294c500ebda07a9fea0df9ec5e9b6a851ca8 100644 (file)
@@ -5339,7 +5339,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5431,21 +5431,25 @@ static void ggml_compute_backward(
             // src1.shape   [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    GGML_ASSERT(tmp->ne[3] == 1);
+
+                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+                    const size_t nb2 = tmp->nb[2] * nr2;
+                    const size_t nb3 = tmp->nb[2];
+
+                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+                    tmp = ggml_repeat_back(ctx, tmp, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5514,7 +5518,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
index 381956a043b89904309e3973f43e973183a2aa37..46801640393020c83f3308139e611ca670d08114 100644 (file)
@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
     }
 };
 
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 4> nr;
+    const bool v; // whether src is a noncontiguous view
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, nr, v);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat_back(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {8, 6, 4, 2},
+            std::array<int, 4> nr = {2, 2, 2, 2},
+            bool v = false)
+        : type(type), ne(ne), nr(nr), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(src, "src");
+
+        if (v) {
+            GGML_ASSERT(ne[0] % 2 == 0);
+            GGML_ASSERT(ne[1] % 2 == 0);
+            GGML_ASSERT(ne[2] % 2 == 0);
+            GGML_ASSERT(ne[3] % 2 == 0);
+            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+        }
+
+        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(target, "target");
+
+        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_DUP
 struct test_dup : public test_case {
     const ggml_type type;
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
         return 5e-4;
     }
 
+    int64_t grad_nmax() override {
+        return 20000;
+    }
+
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
 
             a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
             b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
 
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
         } else {
             a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
         }
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
     }
 
+    for (bool view : {false, true}) {
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+    }
+
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
     test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
@@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             // test cases without permutation
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {2, 2}));
-
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 2}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
 
             // test cases with permutation
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));