]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-quants : fix avx2 iq1_s vec_dot when compiled with gcc (llama/5742)
authorEngininja2 <redacted>
Tue, 27 Feb 2024 12:50:18 +0000 (06:50 -0600)
committerGeorgi Gerganov <redacted>
Wed, 28 Feb 2024 11:00:29 +0000 (13:00 +0200)
ggml-quants.c

index ce654f094da6905110216e3efd43f525a07647ce..73c3bb4123da5ab9f247d6e56050873785e8dab7 100644 (file)
@@ -10248,8 +10248,12 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const
 
     uint64_t aux64;
 
-    __m256i v_gindex;
-    const uint16_t * gindex = (const uint16_t *)&v_gindex;
+    typedef union m256i_uint16 {
+        __m256i reg;
+        uint16_t s[16];
+    } m256i_uint16_t;
+
+    m256i_uint16_t v_gindex;
 
     __m256 accum = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
@@ -10264,13 +10268,13 @@ void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const
             memcpy(&aux64, sc, 8); sc += 8;
             const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
             const __m256i hbit = _mm256_cvtepu8_epi16(_mm_and_si128(qh, m8));
-            v_gindex = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
+            v_gindex.reg = _mm256_or_si256(_mm256_cvtepu8_epi16(ql), _mm256_slli_epi16(hbit, 5));
             const __m128i scales = _mm_or_si128(_mm_slli_epi16(_mm_and_si128(qh, m7), 1), m1);
 
             for (int i32 = 0; i32 < 4; ++i32) {
                 const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
-                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
-                                                      iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
+                const __m256i q1b = _mm256_set_epi64x(iq1s_grid[v_gindex.s[4*i32+3]], iq1s_grid[v_gindex.s[4*i32+2]],
+                                                      iq1s_grid[v_gindex.s[4*i32+1]], iq1s_grid[v_gindex.s[4*i32+0]]);
                 const __m256i dot = mul_add_epi8(q1b, q8b);
                 const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
                 const __m256i p   = _mm256_madd_epi16(s16, dot);