case 0x44:
mc = 4;
nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemm4xN<4>(m0, m, n0, n);
+#else
gemm<4, 4>(m0, m, n0, n);
+#endif
break;
case 0x43:
mc = 4;
nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemm4xN<3>(m0, m, n0, n);
+#else
gemm<4, 3>(m0, m, n0, n);
+#endif
break;
case 0x34:
mc = 3;
nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemmMx4<3>(m0, m, n0, n);
+#else
gemm<3, 4>(m0, m, n0, n);
+#endif
break;
case 0x33:
mc = 3;
case 0x42:
mc = 4;
nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemm4xN<2>(m0, m, n0, n);
+#else
gemm<4, 2>(m0, m, n0, n);
+#endif
break;
case 0x24:
mc = 2;
nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemmMx4<2>(m0, m, n0, n);
+#else
gemm<2, 4>(m0, m, n0, n);
+#endif
break;
#else
case 0x44:
case 0x42:
mc = 4;
nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemm4xN<2>(m0, m, n0, n);
+#else
gemm<4, 2>(m0, m, n0, n);
+#endif
break;
case 0x34:
case 0x24:
mc = 2;
nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemmMx4<2>(m0, m, n0, n);
+#else
gemm<2, 4>(m0, m, n0, n);
+#endif
break;
case 0x33:
#endif
case 0x41:
mc = 4;
nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemm4xN<1>(m0, m, n0, n);
+#else
gemm<4, 1>(m0, m, n0, n);
+#endif
break;
case 0x22:
mc = 2;
case 0x14:
mc = 1;
nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+ gemmMx4<1>(m0, m, n0, n);
+#else
gemm<1, 4>(m0, m, n0, n);
+#endif
break;
case 0x31:
mc = 3;
mnpack(m0, m, np, n);
}
+#if defined(__AVX2__) && defined(__F16C__)
+// Templated functions for gemm of dimensions 4xN
+ template <int RN>
+ NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / 4;
+ int64_t xtiles = (n - n0) / RN;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * 4;
+ int64_t jj = n0 + job % xtiles * RN;
+ __m256 Cv[RN][4] = {};
+ for (int64_t l = 0; l < k; ++l) {
+ uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+ // Convert delta values for four blocks to float values
+ __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+ __m256i avec0 = load(A + lda * (ii + 0) + l);
+ __m256i avec1 = load(A + lda * (ii + 1) + l);
+ __m256i avec2 = load(A + lda * (ii + 2) + l);
+ __m256i avec3 = load(A + lda * (ii + 3) + l);
+ for (int64_t j = 0; j < RN; ++j) {
+ __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+ // Computation of dot product and multiplication with appropriate delta value products
+ Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+ updot(_mm256_sign_epi8(avec0, avec0),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+ Cv[j][0]);
+ Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+ updot(_mm256_sign_epi8(avec1, avec1),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+ Cv[j][1]);
+ Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+ updot(_mm256_sign_epi8(avec2, avec2),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+ Cv[j][2]);
+ Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+ updot(_mm256_sign_epi8(avec3, avec3),
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+ Cv[j][3]);
+ }
+ }
+
+ for (int64_t j = 0; j < RN; ++j)
+ for (int64_t i = 0; i < 4; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+
+ // Templated functions for gemm of dimensions Mx4
+ template <int RM>
+ NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+ int64_t ytiles = (m - m0) / RM;
+ int64_t xtiles = (n - n0) / 4;
+ int64_t tiles = xtiles * ytiles;
+ int64_t duty = (tiles + nth - 1) / nth;
+ int64_t start = duty * ith;
+ int64_t end = start + duty;
+ if (end > tiles)
+ end = tiles;
+ for (int64_t job = start; job < end; ++job) {
+ int64_t ii = m0 + job / xtiles * RM;
+ int64_t jj = n0 + job % xtiles * 4;
+ __m256 Cv[4][RM] = {};
+ for (int64_t l = 0; l < k; ++l) {
+ uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+ // Convert delta values for four blocks to float values
+ __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+ __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+ __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+ __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+ __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+ for (int64_t i = 0; i < RM; ++i) {
+ __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
+ // Computation of dot product and multiplication with appropriate delta value products
+ Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+ Cv[0][i]);
+ Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+ Cv[1][i]);
+ Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+ Cv[2][i]);
+ Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+ load(A + lda * (ii + i) + l)),
+ _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+ Cv[3][i]);
+ }
+ }
+ for (int64_t j = 0; j < 4; ++j)
+ for (int64_t i = 0; i < RM; ++i)
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+ }
+ }
+#endif
+
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;