// QK = number of values after dequantization
// QR = QK / (number of values before dequantization)
+// QI = number of 32 bit integers before dequantization
#define QK4_0 32
#define QR4_0 2
+#define QI4_0 4
typedef struct {
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
#define QK4_1 32
#define QR4_1 2
+#define QI4_1 4
typedef struct {
half d; // delta
half m; // min
#define QK5_0 32
#define QR5_0 2
+#define QI5_0 4
typedef struct {
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
#define QK5_1 32
#define QR5_1 2
+#define QI5_1 4
typedef struct {
half d; // delta
half m; // min
#define QK8_0 32
#define QR8_0 1
+#define QI8_0 8
typedef struct {
half d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
+#define QK8_1 32
+#define QR8_1 1
+#define QI8_1 8
+typedef struct {
+ half d; // delta
+ half s; // unquantized sum
+    int8_t qs[QK8_1]; // quants
+} block_q8_1;
+static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_1, "wrong q8_1 block size/padding");
+
+typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
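+
+// Editor's note: illustrative compile-time checks (not part of the original
+// patch) of the QI relationship: QI = QK / (32 / bits per packed quant),
+// i.e. the number of 32 bit ints that hold one block's qs bytes.
+static_assert(QI4_0 == QK4_0/2/(int)sizeof(int), "4 bit quants: 8 per 32 bit int");
+static_assert(QI8_0 == QK8_0  /(int)sizeof(int), "8 bit quants: 4 per 32 bit int");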
+
//================================= k-quants
#ifdef GGML_QKK_64
#define CUDA_SCALE_BLOCK_SIZE 256
#define CUDA_ROPE_BLOCK_SIZE 256
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
-#ifndef GGML_CUDA_DMMV_Y
-#define GGML_CUDA_DMMV_Y 1
+#ifndef GGML_CUDA_MMV_Y
+#define GGML_CUDA_MMV_Y 1
#endif
#ifndef K_QUANTS_PER_ITERATION
}
// sum up partial sums
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#endif
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#endif
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#endif
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#endif
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
#endif
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
v.y = x[ib + iqs + 1];
}
+static __global__ void quantize_q8_1(const float * x, void * vy, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ block_q8_1 * y = (block_q8_1 *) vy;
+
+    const int ib = i / QK8_1; // block index
+    const int iqs = i % QK8_1; // quant index
+
+ const float xi = x[i];
+ float amax = fabsf(xi);
+ float sum = xi;
+
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
+ sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
+ }
+
+ const float d = amax / 127;
+ const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+ y[ib].qs[iqs] = q;
+
+ if (iqs > 0) {
+ return;
+ }
+
+ y[ib].d = d;
+ y[ib].s = sum;
+}
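+
+// Editor's sketch: a minimal host-side reference for one q8_1 block, handy for
+// cross-checking quantize_q8_1 (hypothetical helper, not part of ggml):
+static void quantize_block_q8_1_ref(const float * x, block_q8_1 * y) {
+    float amax = 0.0f; // absolute max of the block, determines the scale
+    float sum  = 0.0f; // unquantized sum, used by the q*_1 dot products
+    for (int j = 0; j < QK8_1; ++j) {
+        amax = fmaxf(amax, fabsf(x[j]));
+        sum += x[j];
+    }
+    const float d = amax / 127;
+    for (int j = 0; j < QK8_1; ++j) {
+        y->qs[j] = amax == 0.0f ? 0 : (int8_t) roundf(x[j] / d);
+    }
+    y->d = d;
+    y->s = sum;
+}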
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_block(const void * vx, float * y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
y[iybs + iqs + y_offset] = v.y;
}
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics (__dp4a)
+ const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+
+ int vi;
+ memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
+ const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+ const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
+
+ const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
+
+ // subtract 8 from each quantized value
+ const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
+ const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
+
+ // SIMD dot product of quantized values
+ int sumi = __dp4a(vi0, ui0, 0);
+ sumi = __dp4a(vi1, ui1, sumi);
+
+ return sumi*d;
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 610
+}
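+
+// Editor's sketch of what the intrinsics above compute, as scalar device code
+// (assumed-equivalent illustrations, not replacements):
+static __device__ __forceinline__ int dp4a_scalar(const int a, const int b, const int c) {
+    const int8_t * a8 = (const int8_t *) &a; // each int is treated as 4 signed bytes
+    const int8_t * b8 = (const int8_t *) &b;
+    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+}
+
+static __device__ __forceinline__ int vsub4_scalar(const int a, const int b) {
+    int r;
+    const uint8_t * a8 = (const uint8_t *) &a;
+    const uint8_t * b8 = (const uint8_t *) &b;
+    uint8_t * r8 = (uint8_t *) &r;
+    for (int j = 0; j < 4; ++j) {
+        r8[j] = a8[j] - b8[j]; // per-byte subtraction, no borrow across bytes
+    }
+    return r;
+}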
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics (__dp4a)
+ const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+
+ const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
+ const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+ const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
+
+ const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
+ const float m = bq4_1->m;
+ const float s = bq8_1->s;
+
+ const int vi0 = (vi >> 0) & 0x0F0F0F0F;
+ const int vi1 = (vi >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ int sumi = __dp4a(vi0, ui0, 0);
+ sumi = __dp4a(vi1, ui1, sumi);
+
+ return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 610
+}
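+
+// Editor's note on the m*s term above: q4_1 stores x_i = d4*q_i + m, so
+//   sum_i x_i*y_i ~= d4*d8*sum_i q_i*u_i + m*sum_i y_i = sumi*d + m*s,
+// with s = sum_i y_i kept unquantized in block_q8_1. Each of the QI4_1
+// threads on this block adds m*s/QI4_1, so the warp reduction in
+// mul_mat_vec_q recovers the full m*s exactly once per block.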
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics (__dp4a)
+ const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+
+ int qs;
+ memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
+ const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
+ const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
+ const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+ const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
+
+ const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
+
+ int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
+ vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
+ vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
+ vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
+ vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
+ vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values
+ int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
+
+ int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
+ vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
+ vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
+ vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
+ vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
+ vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values
+ sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
+
+ return sumi*d;
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 610
+}
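+
+// Editor's note on the shift constants above: bit j (0-based) of the qh
+// nibble is the 5th bit of the quant in byte j of vi0/vi1, so it must land
+// in bit 8*j + 4 of the packed int: 0->4, 1->12, 2->20, 3->28, giving the
+// shifts 4, 11, 18, 25 and the single-bit masks used above.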
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics (__dp4a)
+ const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+
+ const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
+ const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
+ const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
+ const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+ const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
+
+ const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
+ const float m = bq5_1->m;
+ const float s = bq8_1->s;
+
+ int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
+ vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
+ vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
+ vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
+ vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
+ int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
+
+ int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
+ vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
+ vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
+ vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
+ vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
+ sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
+
+ return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 610
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics (__dp4a)
+ const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+
+ int vi;
+ memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
+ const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
+
+ const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
+
+ // SIMD dot product of quantized values
+ int sumi = __dp4a(vi, ui, 0);
+
+ return sumi*d;
+#else
+ return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= 610
+}
+
+template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
+ const int row = blockIdx.y*blockDim.y + threadIdx.y;
+
+ if (row >= nrows) {
+ return;
+ }
+
+ const int blocks_per_row = ncols / qk;
+ const int blocks_per_warp = WARP_SIZE / qi;
+
+    // partial sum for each thread
+ float tmp = 0.0f;
+
+ const block_q_t * x = (const block_q_t *) vx;
+ const block_q8_1 * y = (const block_q8_1 *) vy;
+
+ for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
+ const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
+
+ const int iby = i + threadIdx.x / qi; // y block index
+
+ const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int
+
+ tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ }
+
+ if (threadIdx.x == 0) {
+ dst[row] = tmp;
+ }
+}
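+
+// Editor's worked example of the mapping above, for q4_0 (qi = QI4_0 = 4):
+// each warp covers WARP_SIZE/qi = 8 x blocks per iteration; lanes 0-3 share
+// block i+0 with iqs = 0..3 (one 4 byte slice of qs each), lanes 4-7 share
+// block i+1, and so on. For ncols = 4096 a row has 4096/32 = 128 blocks, so
+// the loop runs 128/8 = 16 iterations per row.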
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
}
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
const int idst = channel*nrows_dst + row_dst;
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
// sum up partial sums and write back result
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
// sum up partial sums
- __syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
+static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ quantize_q8_1<<<num_blocks, CUDA_QUANTIZE_BLOCK_SIZE, 0, stream>>>(x, vy, k);
+}
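+
+// Editor's usage sketch (hypothetical device buffers allocated by the caller):
+// quantize one 4096-wide activation row, then multiply it against a 4096x4096
+// q4_0 weight matrix with the kernel wrappers defined further below:
+//   quantize_row_q8_1_cuda(src1_f32, src1_q8_1, 4096, stream);
+//   mul_mat_vec_q4_0_q8_1_cuda(w_q4_0, src1_q8_1, dst_f32, 4096, 4096, stream);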
+
static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, vec_dot_q4_0_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK4_1, QI4_1, block_q4_1, vec_dot_q4_1_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, vec_dot_q5_0_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, vec_dot_q5_1_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
+static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+ const dim3 block_nums(1, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+ mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, vec_dot_q8_0_q8_1>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
- const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
+ const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
+ const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_f16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static int g_device_count = -1;
static int g_main_device = 0;
+static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
for (int id = 0; id < g_device_count; ++id) {
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
- fprintf(stderr, " Device %d: %s\n", id, prop.name);
+ fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+
g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
+
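+        // encoded as 100*major + 10*minor (e.g. 6.1 -> 610) and compared against
+        // the 700 threshold in ggml_cuda_op_mul_mat_vec below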
+ g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
}
for (int id = 0; id < g_device_count; ++id) {
g_tensor_split[id] /= total_vram;
(void) i1;
}
-inline void ggml_cuda_op_dequantize_mul_mat_vec(
+inline void ggml_cuda_op_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
cudaStream_t & cudaStream_main){
const int64_t ne00 = src0->ne[0];
const int64_t nrows = i01_high - i01_low;
-// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
-#ifdef GGML_CUDA_DMMV_F16
- size_t ash;
- dfloat * src1_dfloat = nullptr; // dfloat == half
+#ifdef GGML_CUDA_FORCE_DMMV
+ const bool use_mul_mat_vec_q = false;
+#else
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
- bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
- src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
- src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+ const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 ||
+ src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 ||
+ src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0;
- if (src1_convert_f16) {
- src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
- ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
- ne00, 1, sizeof(float), 0, 0,
- ne00, 1, sizeof(half), 0, 0, cudaStream_main);
- }
+    // The integer intrinsics used in mul_mat_vec_q are available from compute capability 6.1.
+    // However, they have bad performance on Pascal cards.
+    // Therefore, in a multi GPU setting, decide at runtime which GPUs should use mul_mat_vec_q.
+ const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented;
+#endif
+
+ if (use_mul_mat_vec_q) {
+ size_t as;
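+        // ne00/QK8_1 blocks of sizeof(block_q8_1) = 36 bytes each, e.g. 4096 floats -> 128 blocks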
+ void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as);
+ quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main);
+
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+
+ ggml_cuda_pool_free(src1_q8_1, as);
+ } else {
+ // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_DMMV_F16
+ size_t ash;
+ dfloat * src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+ ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+ ne00, 1, sizeof(float), 0, 0,
+ ne00, 1, sizeof(half), 0, 0, cudaStream_main);
+ }
#else
- dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+ dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_CUDA_DMMV_F16
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q4_1:
- dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_0:
- dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_1:
- dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q8_0:
- dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q2_K:
- dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q3_K:
- dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q4_K:
- dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q5_K:
- dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_Q6_K:
- dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- case GGML_TYPE_F16:
- convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
- break;
- default:
- GGML_ASSERT(false);
- break;
- }
+ switch (src0->type) {
+ case GGML_TYPE_Q4_0:
+ dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q4_1:
+ dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_0:
+ dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_1:
+ dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q8_0:
+ dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q2_K:
+ dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q3_K:
+ dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q4_K:
+ dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q5_K:
+ dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_Q6_K:
+ dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ case GGML_TYPE_F16:
+ convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
+ break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
#ifdef GGML_CUDA_DMMV_F16
- if (src1_convert_f16) {
- ggml_cuda_pool_free(src1_dfloat, ash);
- }
+ if (src1_convert_f16) {
+ ggml_cuda_pool_free(src1_dfloat, ash);
+ }
#endif // GGML_CUDA_DMMV_F16
+ }
(void) src1;
(void) dst;
}else if (src0->type == GGML_TYPE_F32) {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
- if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false);
+ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
} else {
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
}
}
void ggml_cuda_free_data(struct ggml_tensor * tensor) {
- if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
+    if (!tensor || (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT)) {
return;
}
return GGML_FP32_TO_FP16(x);
}
-void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
- for (size_t i = 0; i < n; i++) {
+void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n) {
+ for (int i = 0; i < n; i++) {
y[i] = GGML_FP16_TO_FP32(x[i]);
}
}
-void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
- size_t i = 0;
+void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) {
+ int i = 0;
#if defined(__F16C__)
for (; i + 7 < n; i += 8) {
__m256 x_vec = _mm256_loadu_ps(x + i);
}
}
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+ [GGML_TYPE_F32] = {
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
+ .vec_dot_type = GGML_TYPE_F32,
+ },
+ [GGML_TYPE_F16] = {
+ .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
+ .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+ .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+ .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
+ .vec_dot_type = GGML_TYPE_F16,
+ },
[GGML_TYPE_Q4_0] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_0,
- .quantize_row_q = quantize_row_q4_0,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_0,
+ .from_float = quantize_row_q4_0,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
+ .vec_dot = ggml_vec_dot_q4_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
},
[GGML_TYPE_Q4_1] = {
- .dequantize_row_q = (dequantize_row_q_t)dequantize_row_q4_1,
- .quantize_row_q = quantize_row_q4_1,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
- .quantize_row_q_dot = quantize_row_q8_1,
- .vec_dot_q = ggml_vec_dot_q4_1_q8_1,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_1,
+ .from_float = quantize_row_q4_1,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
+ .vec_dot = ggml_vec_dot_q4_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1,
},
[GGML_TYPE_Q5_0] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_0,
- .quantize_row_q = quantize_row_q5_0,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_0_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q5_0_q8_0,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_0,
+ .from_float = quantize_row_q5_0,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
+ .vec_dot = ggml_vec_dot_q5_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
},
[GGML_TYPE_Q5_1] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_1,
- .quantize_row_q = quantize_row_q5_1,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_1_reference,
- .quantize_row_q_dot = quantize_row_q8_1,
- .vec_dot_q = ggml_vec_dot_q5_1_q8_1,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_1,
+ .from_float = quantize_row_q5_1,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
+ .vec_dot = ggml_vec_dot_q5_1_q8_1,
.vec_dot_type = GGML_TYPE_Q8_1,
},
[GGML_TYPE_Q8_0] = {
- .dequantize_row_q = dequantize_row_q8_0,
- .quantize_row_q = quantize_row_q8_0,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
- .quantize_row_q_dot = quantize_row_q8_0,
- .vec_dot_q = ggml_vec_dot_q8_0_q8_0,
+ .to_float = dequantize_row_q8_0,
+ .from_float = quantize_row_q8_0,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
+ .vec_dot = ggml_vec_dot_q8_0_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
},
[GGML_TYPE_Q8_1] = {
- .dequantize_row_q = NULL, // TODO
- .quantize_row_q = quantize_row_q8_1,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_1_reference,
- .quantize_row_q_dot = quantize_row_q8_1,
- .vec_dot_q = NULL, // TODO
+ .from_float = quantize_row_q8_1,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
.vec_dot_type = GGML_TYPE_Q8_1,
},
#ifdef GGML_USE_K_QUANTS
[GGML_TYPE_Q2_K] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q2_K,
- .quantize_row_q = quantize_row_q2_K,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q2_K_reference,
- .quantize_row_q_dot = quantize_row_q8_K,
- .vec_dot_q = ggml_vec_dot_q2_K_q8_K,
+ .to_float = (ggml_to_float_t) dequantize_row_q2_K,
+ .from_float = quantize_row_q2_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
+ .vec_dot = ggml_vec_dot_q2_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q3_K] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q3_K,
- .quantize_row_q = quantize_row_q3_K,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q3_K_reference,
- .quantize_row_q_dot = quantize_row_q8_K,
- .vec_dot_q = ggml_vec_dot_q3_K_q8_K,
+ .to_float = (ggml_to_float_t) dequantize_row_q3_K,
+ .from_float = quantize_row_q3_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
+ .vec_dot = ggml_vec_dot_q3_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q4_K] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_K,
- .quantize_row_q = quantize_row_q4_K,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_K_reference,
- .quantize_row_q_dot = quantize_row_q8_K,
- .vec_dot_q = ggml_vec_dot_q4_K_q8_K,
+ .to_float = (ggml_to_float_t) dequantize_row_q4_K,
+ .from_float = quantize_row_q4_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
+ .vec_dot = ggml_vec_dot_q4_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q5_K] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q5_K,
- .quantize_row_q = quantize_row_q5_K,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q5_K_reference,
- .quantize_row_q_dot = quantize_row_q8_K,
- .vec_dot_q = ggml_vec_dot_q5_K_q8_K,
+ .to_float = (ggml_to_float_t) dequantize_row_q5_K,
+ .from_float = quantize_row_q5_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
+ .vec_dot = ggml_vec_dot_q5_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q6_K] = {
- .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q6_K,
- .quantize_row_q = quantize_row_q6_K,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q6_K_reference,
- .quantize_row_q_dot = quantize_row_q8_K,
- .vec_dot_q = ggml_vec_dot_q6_K_q8_K,
+ .to_float = (ggml_to_float_t) dequantize_row_q6_K,
+ .from_float = quantize_row_q6_K,
+ .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
+ .vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
+ [GGML_TYPE_Q8_K] = {
+ .from_float = quantize_row_q8_K,
+ }
#endif
};
// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type i) {
GGML_ASSERT(i < GGML_TYPE_COUNT);
- return quantize_fns[i];
+ return type_traits[i];
}
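+
+// Editor's usage sketch for the traits table (hypothetical buffers, n assumed
+// to be a multiple of the block size):
+//   ggml_type_traits_t tt = ggml_internal_get_type_traits(GGML_TYPE_Q4_0);
+//   tt.from_float(x_f32, q4_row, n);  // quantize one row
+//   tt.to_float(q4_row, x_back, n);   // dequantize it again
+//   // for dot products, quantize the second operand to tt.vec_dot_type
+//   // (GGML_TYPE_Q8_0 here) and call tt.vec_dot(n, &result, q4_row, q8_row);
+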
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
-inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
+static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
#ifdef GGML_SIMD
float sumf = 0.0f;
const int np = (n & ~(GGML_F32_STEP - 1));
*s = sumf;
}
-inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
+static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
ggml_float sumf = 0.0;
#if defined(GGML_SIMD)
id += ne00 * (ne01 - ir1);
}
}
- } else if (ggml_is_quantized(dst->type)) {
- quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+ } else if (type_traits[dst->type].from_float) {
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
size_t id = 0;
id += rs * (ne01 - ir1);
}
}
- } else if (dst->type == GGML_TYPE_F16) {
- size_t id = 0;
- ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
-
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = 0; i02 < ne02; i02++) {
- id += ne00 * ir0;
- for (int i01 = ir0; i01 < ir1; i01++) {
- for (int i00 = 0; i00 < ne00; i00++) {
- const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
-
- dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
- id++;
- }
- }
- id += ne00 * (ne01 - ir1);
- }
- }
- } else if (ggml_is_quantized(dst->type)) {
- quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
+ } else if (type_traits[dst->type].from_float) {
+ ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float;
size_t id = 0;
size_t rs = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
const int nth = params->nth;
const enum ggml_type type = src0->type;
- dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+ ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
GGML_TENSOR_UNARY_OP_LOCALS;
const enum ggml_type type = src0->type;
- dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
+ ggml_from_float_t const quantize_row_q = type_traits[type].from_float;
// we don't support permuted src0
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
}
#endif
-static void ggml_compute_forward_mul_mat_f32(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * src0,
- const struct ggml_tensor * src1,
- struct ggml_tensor * dst) {
- int64_t t0 = ggml_perf_time_us();
- UNUSED(t0);
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- assert(ne02 == ne12);
- assert(ne03 == ne13);
- assert(ne2 == ne12);
- assert(ne3 == ne13);
-
- // we don't support permuted src0 or src1
- assert(nb00 == sizeof(float));
- assert(nb10 == sizeof(float));
-
- // dst cannot be transposed or permuted
- assert(nb0 == sizeof(float));
- assert(nb0 <= nb1);
- assert(nb1 <= nb2);
- assert(nb2 <= nb3);
-
- assert(ne0 == ne01);
- assert(ne1 == ne11);
- assert(ne2 == ne02);
- assert(ne3 == ne03);
-
- // nb01 >= nb00 - src0 is not transposed
- // compute by src0 rows
-
-#if defined(GGML_USE_CLBLAST)
- if (ggml_cl_can_mul_mat(src0, src1, dst)) {
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
- ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
- }
- return;
- }
-#endif
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
- if (params->ith != 0) {
- return;
- }
-
- if (params->type == GGML_TASK_INIT) {
- return;
- }
-
- if (params->type == GGML_TASK_FINALIZE) {
- return;
- }
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- ne11, ne01, ne10,
- 1.0f, y, ne10,
- x, ne00,
- 0.0f, d, ne01);
- }
- }
- //printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
-
- return;
- }
-#endif
-
- if (params->type == GGML_TASK_INIT) {
- return;
- }
-
- if (params->type == GGML_TASK_FINALIZE) {
- return;
- }
-
- // parallelize by src0 rows using ggml_vec_dot_f32
-
- // total rows in src0
- const int nr = ne01*ne02*ne03;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- for (int64_t ic = 0; ic < ne11; ++ic) {
- // src1 indices
- const int i13 = i03;
- const int i12 = i02;
- const int i11 = ic;
-
- // dst indices
- const int i0 = i01;
- const int i1 = i11;
- const int i2 = i02;
- const int i3 = i03;
-
- ggml_vec_dot_f32(ne00,
- (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
- (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
- }
- }
-
- //int64_t t1 = ggml_perf_time_us();
- //static int64_t acc = 0;
- //acc += t1 - t0;
- //if (t1 - t0 > 10) {
- // printf("\n");
- // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
- // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
- // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
- // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13);
-
- // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
- //}
-}
-
-static void ggml_compute_forward_mul_mat_f16_f32(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * src0,
- const struct ggml_tensor * src1,
- struct ggml_tensor * dst) {
- int64_t t0 = ggml_perf_time_us();
- UNUSED(t0);
-
- GGML_TENSOR_BINARY_OP_LOCALS;
-
- //const int64_t ne = ne0*ne1*ne2*ne3;
-
- const int ith = params->ith;
- const int nth = params->nth;
-
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne03 == ne13);
- GGML_ASSERT(ne2 == ne12);
- GGML_ASSERT(ne3 == ne13);
-
- // TODO: we don't support permuted src0
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
-
- // dst cannot be transposed or permuted
- GGML_ASSERT(nb0 == sizeof(float));
- GGML_ASSERT(nb0 <= nb1);
- GGML_ASSERT(nb1 <= nb2);
- GGML_ASSERT(nb2 <= nb3);
-
- GGML_ASSERT(ne0 == ne01);
- GGML_ASSERT(ne1 == ne11);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne3 == ne03);
-
- // nb01 >= nb00 - src0 is not transposed
- // compute by src0 rows
-
-#if defined(GGML_USE_CLBLAST)
- if (ggml_cl_can_mul_mat(src0, src1, dst)) {
- if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
- ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
- }
- return;
- }
-#endif
-
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
- GGML_ASSERT(nb10 == sizeof(float));
-
- if (params->ith != 0) {
- return;
- }
-
- if (params->type == GGML_TASK_INIT) {
- return;
- }
-
- if (params->type == GGML_TASK_FINALIZE) {
- return;
- }
-
- for (int64_t i03 = 0; i03 < ne03; i03++) {
- for (int64_t i02 = 0; i02 < ne02; i02++) {
- float * const wdata = params->wdata;
- {
- size_t id = 0;
- for (int64_t i01 = 0; i01 < ne01; ++i01) {
- for (int64_t i00 = 0; i00 < ne00; ++i00) {
- wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
- }
- }
-
- assert(id*sizeof(float) <= params->wsize);
- }
-
- const float * x = wdata;
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
-
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-
- // zT = y * xT
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- ne11, ne01, ne10,
- 1.0f, y, ne10,
- x, ne00,
- 0.0f, d, ne01);
- }
- }
-
- /*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
-
- return;
- }
-#endif
-
- if (params->type == GGML_TASK_INIT) {
- ggml_fp16_t * const wdata = params->wdata;
-
- size_t id = 0;
- for (int64_t i13 = 0; i13 < ne13; ++i13) {
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = 0; i11 < ne11; ++i11) {
- for (int64_t i10 = 0; i10 < ne10; ++i10) {
- wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
- }
- }
- }
- }
-
- GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
- return;
- }
-
- if (params->type == GGML_TASK_FINALIZE) {
- return;
- }
-
- // fp16 -> half the size, so divide by 2
- // TODO: do not support transposed src1
- assert(nb10/2 == sizeof(ggml_fp16_t));
-
- // parallelize by src0 rows using ggml_vec_dot_f16
-
- // total rows in src0
- const int nr = ne01*ne02*ne03;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- ggml_fp16_t * wdata = params->wdata;
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int i13 = i03;
- const int i12 = i02;
-
- const int i0 = i01;
- const int i2 = i02;
- const int i3 = i03;
-
- ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
- ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00;
-
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
- for (int64_t ic = 0; ic < ne11; ++ic) {
- ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
- }
- }
-
- //int64_t t1 = ggml_time_us();
- //static int64_t acc = 0;
- //acc += t1 - t0;
- //if (t1 - t0 > 10) {
- // printf("\n");
- // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
- // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
- // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-
- // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
- //}
-}
-
-static void ggml_compute_forward_mul_mat_q_f32(
+static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
GGML_ASSERT(ne3 == ne13);
const enum ggml_type type = src0->type;
- quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
- enum ggml_type const vec_dot_type = quantize_fns[type].vec_dot_type;
+
+ ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
+ enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
+ ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
return;
}
- float * const wdata = params->wdata;
- dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
-
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
+ const void * x = (char *) src0->data + i03*nb03 + i02*nb02;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
- {
+ if (type != GGML_TYPE_F32) {
+ float * const wdata = params->wdata;
+ ggml_to_float_t const to_float = type_traits[type].to_float;
+
size_t id = 0;
for (int64_t i01 = 0; i01 < ne01; ++i01) {
- dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
+ to_float((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
id += ne00;
}
assert(id*sizeof(float) <= params->wsize);
+ x = wdata;
}
- const float * x = wdata;
-
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
ne11, ne01, ne10,
1.0f, y, ne10,
#endif
if (params->type == GGML_TASK_INIT) {
- char * wdata = params->wdata;
- const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
-
- for (int64_t i13 = 0; i13 < ne13; ++i13) {
- for (int64_t i12 = 0; i12 < ne12; ++i12) {
- for (int64_t i11 = 0; i11 < ne11; ++i11) {
- quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
- wdata += row_size;
+ if (src1->type != vec_dot_type) {
+ char * wdata = params->wdata;
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+
+ for (int64_t i13 = 0; i13 < ne13; ++i13) {
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ wdata += row_size;
+ }
}
}
}
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- void * wdata = params->wdata;
+ void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
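+    // e.g. vec_dot_type == GGML_TYPE_Q8_0: GGML_TYPE_SIZE is sizeof(block_q8_0) = 34
+    // and GGML_BLCK_SIZE is 32, so each quantized src1 row takes (ne00/32)*34 bytes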
for (int ir = ir0; ir < ir1; ++ir) {
assert(ne00 % 32 == 0);
for (int64_t ic = 0; ic < ne11; ++ic) {
- vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+ vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
}
//}
}
-static void ggml_compute_forward_mul_mat(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * src0,
- const struct ggml_tensor * src1,
- struct ggml_tensor * dst) {
- switch (src0->type) {
- case GGML_TYPE_Q4_0:
- case GGML_TYPE_Q4_1:
- case GGML_TYPE_Q5_0:
- case GGML_TYPE_Q5_1:
- case GGML_TYPE_Q8_0:
- case GGML_TYPE_Q8_1:
- case GGML_TYPE_Q2_K:
- case GGML_TYPE_Q3_K:
- case GGML_TYPE_Q4_K:
- case GGML_TYPE_Q5_K:
- case GGML_TYPE_Q6_K:
- {
- ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
- } break;
- case GGML_TYPE_F16:
- {
- ggml_compute_forward_mul_mat_f16_f32(params, src0, src1, dst);
- } break;
- case GGML_TYPE_F32:
- {
- ggml_compute_forward_mul_mat_f32(params, src0, src1, dst);
- } break;
- default:
- {
- GGML_ASSERT(false);
- } break;
- }
-}
// ggml_compute_forward_out_prod
const int nc = src0->ne[0];
const int nr = ggml_nelements(src1);
const enum ggml_type type = src0->type;
- dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+ ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
assert( dst->ne[0] == nc);
assert( dst->ne[1] == nr);
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
size_t cur = 0;
+ const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
}
else
#endif
- if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+            node->n_tasks = 1; // TODO: this is actually doing nothing;
+                               //       the threads are still spinning
+ if (node->src0->type != GGML_TYPE_F32) {
// here we need memory just for a single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
- } else {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
- }
-#else
- cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-#endif
- } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
- cur = 0;
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1;
}
+ } else
#endif
- } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1;
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
- } else
-#endif
- {
- const enum ggml_type type_q = quantize_fns[node->src0->type].vec_dot_type;
- cur = GGML_TYPE_SIZE[type_q]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[type_q];
- }
+ if (node->src1->type != vec_dot_type) {
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
} else {
GGML_ASSERT(false);
}