]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add ggml_scale_bias (llama/14417)
authorXuan-Son Nguyen <redacted>
Wed, 9 Jul 2025 16:16:12 +0000 (18:16 +0200)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
* ggml : add ggml_scale_bias

* ggml_vec_mad1_f32

* add more simd

* add CUDA

* sycl

* vulkan

* cann (placeholder)

* opencl

* will this fix cpu?

* fix cuda

* suggestions from coderabbit

* fix cann compile error

* vDSP_vsmsa

* rm __ARM_FEATURE_SVE

* use memcpy for op params

* make code looks more consistent

* use scalar for __ARM_FEATURE_SVE

* add x param to ggml_vec_mad1_f32

14 files changed:
include/ggml.h
src/ggml-cann/ggml-cann.cpp
src/ggml-cpu/ops.cpp
src/ggml-cpu/vec.h
src/ggml-cuda/scale.cu
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
src/ggml-opencl/ggml-opencl.cpp
src/ggml-opencl/kernels/scale.cl
src/ggml-sycl/ggml-sycl.cpp
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/scale.comp
src/ggml.c
tests/test-backend-ops.cpp

index 8b0bec8738d833fc506e59e7c2dc4f9b681f2dcd..13044602fd866529587f115dd6486856e7d635b8 100644 (file)
@@ -1294,6 +1294,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);
 
+    // x = s * a + b
+    GGML_API struct ggml_tensor * ggml_scale_bias(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        float                 b);
+
+    GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        float                 b);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
index eae575cc040cdfd19f4575ec93d4efd533f99d11..ccb17eb072eb23ecf69766d4b4078f0110fa511a 100755 (executable)
@@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_MUL:
         case GGML_OP_DIV:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
@@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_SCALE:
+            float bias;
+            memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
+            return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
index aaeee614ab993b33438aa2291c7889fc55417bac..fd77e9a6abad570077e0a608964acb8c7cb9c4e6 100644 (file)
@@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
-    // scale factor
-    float v;
-    memcpy(&v, dst->op_params, sizeof(float));
+    float s; // scale factor
+    float b; // bias
+
+    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
 
     const size_t nb1 = dst->nb[1];
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        if (dst->data != src0->data) {
-            // src0 is same shape as dst => same indices
-            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+    if (b == 0.0f) {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
+                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+            }
+            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
+        }
+    } else {
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            ggml_vec_mad1_f32(nc,
+                (float *) ((char *) dst->data  + i1*nb1),
+                (float *) ((char *) src0->data + i1*nb1),
+                s, b);
         }
-        ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
     }
 }
 
index 1f5857a23e35c06aac9934c25c09e26047034332..d18783a00a1a5790f80af40a9a8c032e190f607b 100644 (file)
@@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
 #endif
 }
 
+inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
+#if defined(GGML_USE_ACCELERATE)
+    vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
+#elif defined(GGML_SIMD)
+    #if defined(__ARM_FEATURE_SVE)
+        // scalar ; TODO: Write SVE code
+        for (int i = 0; i < n; ++i) {
+            y[i] = x[i]*s + b;
+        }
+    #else
+        const int np = (n & ~(GGML_F32_STEP - 1));
+
+        GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
+        GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);
+
+        GGML_F32_VEC ay[GGML_F32_ARR];
+
+        for (int i = 0; i < np; i += GGML_F32_STEP) {
+            for (int j = 0; j < GGML_F32_ARR; j++) {
+                ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
+                ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
+
+                GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
+            }
+        }
+
+        // leftovers
+        for (int i = np; i < n; ++i) {
+            y[i] = x[i]*s + b;
+        }
+    #endif
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = x[i]*s + b;
+    }
+#endif
+}
+
 //inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          }
 inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #if defined(GGML_USE_ACCELERATE)
index 1405e066e86a29a4e56c6bda59e4f43ab3f40b17..2ee9e588992f46cce94710d684b7199fed441907 100644 (file)
@@ -1,18 +1,18 @@
 #include "scale.cuh"
 
-static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= k) {
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
 
-static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
-    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
+    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
 }
 
 void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
+    scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
 }
index 40fc315e82fd15d96019ee7cb15edb8fc180ffba..83a0739809a6e53b0ea1a4f92c07c4b47857a797 100644 (file)
@@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(ggml_is_contiguous(src0));
 
                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(scale));
+                float bias;
+                memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&bias,  ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
                 int64_t n = ggml_nelements(dst);
 
@@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
                 [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+                [encoder setBytes:&bias  length:sizeof(bias)  atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
index 22240bab472493178ea8312cb7ea934f3fb25321..239ec31fbcb5892cc0b694f99f6f3a24ef35e517 100644 (file)
@@ -1014,16 +1014,18 @@ kernel void kernel_scale(
         device const float * src0,
         device       float * dst,
         constant     float & scale,
+        constant     float & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
         constant     float  & scale,
+        constant     float  & bias,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] * scale;
+    dst[tpig] = src0[tpig] * scale + bias;
 }
 
 kernel void kernel_clamp(
index a9fc039038705bda0cbc28c05ba68a9ed9ad6566..43d8e5c72c93795ca731a461cf46273e476e7d74 100644 (file)
@@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(scale));
+    float bias;
+    memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
+    memcpy(&bias,  ((int32_t *) dst->op_params) + 1, sizeof(float));
 
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
     CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
     CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float),    &bias));
 
     int n = ggml_nelements(dst)/4;
 
index 8cfd518fa5a3ef5b566558f9564d48d4f8773b24..aeca8a456e4fe30111ea6c084b25d1a14ae6f3a4 100644 (file)
@@ -8,9 +8,10 @@ kernel void kernel_scale(
         ulong offset0,
         global float4 * dst,
         ulong offsetd,
-        float scale
+        float scale,
+        float bias
 ) {
     src0 = (global float4*)((global char*)src0 + offset0);
     dst = (global float4*)((global char*)dst + offsetd);
-    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
 }
index 21c81e99a19aa560ef4b5c7f2b87224c2d556225..cd15bbdb29fa2bbc50d6119c415aa0577060ed14 100644 (file)
@@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
+static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
                       const sycl::nd_item<3> &item_ct1) {
     const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                   item_ct1.get_local_id(2);
@@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
         return;
     }
 
-    dst[i] = scale * x[i];
+    dst[i] = scale * x[i] + bias;
 }
 
 
@@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
 
 
 
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
+static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
                            const int k, queue_ptr stream) {
     const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
     stream->parallel_for(
@@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
                               sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
                           sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
         [=](sycl::nd_item<3> item_ct1) {
-            scale_f32(x, dst, scale, k, item_ct1);
+            scale_f32(x, dst, scale, bias, k, item_ct1);
         });
 }
 
@@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     float *       dst_dd  = static_cast<float *>(dst->data);
 
     float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float bias;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&bias,  (float *) dst->op_params + 1, sizeof(float));
 
-    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
+    scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
     /*
     DPCT1010:87: SYCL uses exceptions to report errors and does not use the
     error codes. The call was replaced with 0. You need to rewrite this code.
index 2245a655498c5f147d6a2835f0d59ab4d72a0278..c36e1a6d3bfc26f91028d7b4b5fbbe1f01945a7f 100644 (file)
@@ -7508,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
-        op_params[0], 0.0f,
+        op_params[0], op_params[1],
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
     }, dryrun);
 }
index 4663428dee0a2878987a2eda364ebec1c66b5169..f10b0a02b5076675fc3510266ecd310592422020 100644 (file)
@@ -18,7 +18,7 @@ void main() {
             continue;
         }
 
-        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
+        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
         idx += num_threads;
     }
 }
index e2d9d616a596cc86bfa1adbc00ad9e46465d792f..c8d1f144d0379e6518648c1be0218b32ea766c42 100644 (file)
@@ -3061,12 +3061,14 @@ static struct ggml_tensor * ggml_scale_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s,
+        float                 b,
         bool                  inplace) {
     GGML_ASSERT(ggml_is_padded_1d(a));
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    ggml_set_op_params(result, &s, sizeof(s));
+    float params[2] = { s, b };
+    ggml_set_op_params(result, &params, sizeof(params));
 
     result->op     = GGML_OP_SCALE;
     result->src[0] = a;
@@ -3078,14 +3080,30 @@ struct ggml_tensor * ggml_scale(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s) {
-    return ggml_scale_impl(ctx, a, s, false);
+    return ggml_scale_impl(ctx, a, s, 0.0, false);
 }
 
 struct ggml_tensor * ggml_scale_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         float                 s) {
-    return ggml_scale_impl(ctx, a, s, true);
+    return ggml_scale_impl(ctx, a, s, 0.0, true);
+}
+
+struct ggml_tensor * ggml_scale_bias(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        float                 b) {
+    return ggml_scale_impl(ctx, a, s, b, false);
+}
+
+struct ggml_tensor * ggml_scale_bias_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 s,
+        float                 b) {
+    return ggml_scale_impl(ctx, a, s, b, true);
 }
 
 // ggml_set
@@ -5769,7 +5787,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
             }
         } break;
         case GGML_OP_REPEAT: {
@@ -5846,7 +5864,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 float s;
                 memcpy(&s, tensor->op_params, sizeof(float));
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
             }
         } break;
         case GGML_OP_SET: {
index b54bcc8a35e640b27b086bf340cb927b38c0675b..1d837b4322cfa7a5d21cd468bb1c0a5751d6f49d 100644 (file)
@@ -2368,22 +2368,24 @@ struct test_scale : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     float scale;
+    float bias;
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, scale);
+        return VARS_TO_STR4(type, ne, scale, bias);
     }
 
     test_scale(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
-            float scale = 2.0f)
-        : type(type), ne(ne), scale(scale) {}
+            float scale = 2.0f,
+            float bias = 0.0f)
+        : type(type), ne(ne), scale(scale), bias(bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
-        ggml_tensor * out = ggml_scale(ctx, a, scale);
+        ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
         ggml_set_name(out, "out");
 
         return out;
@@ -5044,6 +5046,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
+    test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
     test_cases.emplace_back(new test_silu_back());
 
     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {