]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
hexagon: dma optimizations (mostly fixing regressions) (llama/21137)
authorMax Krasnyansky <redacted>
Sun, 29 Mar 2026 13:40:13 +0000 (06:40 -0700)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
* hex-fa: add simple dma cache for Mask

I noticed that we were refetch the mask rows over and over.
This simple cache avoids that.

* hex-dma: unset in-order desc bit which caused signficant perf regression

We don't rely on true in order processing of the DMA descriptors anywhere.
Turns out this mode caused significant regression of around 3-4 TPS during token gen.

* hex-rope: update comment to clarify that we don't need in-order DMA completions

src/ggml-hexagon/htp/flash-attn-ops.c
src/ggml-hexagon/htp/hex-dma.h
src/ggml-hexagon/htp/rope-ops.c

index 6dc978dd68a1f1286ecb2f9b56742ebfe42109ac..0c9bc785620f1ebfc409850f622506aaa32a30cb 100644 (file)
@@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
 
     const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
 
+    dma_cache m_cache;
+    dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);
+
     for (uint32_t ir = ir0; ir < ir1; ++ir) {
         const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
         const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
@@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
             // Mask
             if (mask) {
                 const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
-                uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
                 // Mask is 1D contiguous for this row
-                dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
+                dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
             }
 
             // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
@@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
                 // Mask
                 if (mask) {
                     const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
-                    dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
+                    dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
                 }
 
                 // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
@@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
     octx->src0_spad.size_per_thread = size_q_block * 1;
     octx->src1_spad.size_per_thread = factx.size_k_block * 2;
     octx->src2_spad.size_per_thread = factx.size_v_block * 2;
-    octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
+    octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
     octx->dst_spad.size_per_thread  = size_vkq_acc;
 
     octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
@@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
     octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
     octx->dst_spad.data  = octx->src3_spad.data + octx->src3_spad.size;
 
+    // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
+
     if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
         worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
     }
index ff166cbcc7af5999afc5c2871ef1a323dbd0164a..7685473f4631200bc52993adf80f49e1e49c27ed 100644 (file)
@@ -143,7 +143,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
     desc->desc_size  = 0; // 1D mode
     desc->src_bypass = dma_src_l2_bypass_on;
     desc->dst_bypass = dma_dst_l2_bypass_on;
-    desc->order      = 1;
+    desc->order      = 0;
     desc->done       = 0;
     desc->src        = (void *) dptr.src;
     desc->dst        = (void *) dptr.dst;
@@ -151,8 +151,12 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
 
     q->dptr[q->push_idx] = dptr;
 
-    dmlink(q->tail, desc);
-    q->tail = (dma_descriptor_2d *) desc;
+    if (size) {
+        dmlink(q->tail, desc);
+        q->tail = (dma_descriptor_2d *) desc;
+    } else {
+        desc->done = 1;
+    }
 
     // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
     q->push_idx = (q->push_idx + 1) & q->idx_mask;
@@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
     desc->dst_bypass     = dma_dst_l2_bypass_on;
     desc->src_comp       = 0;
     desc->dst_comp       = 0;
-    desc->order          = 1;
+    desc->order          = 0;
     desc->done           = 0;
     desc->src_stride     = src_stride;
     desc->dst_stride     = dst_stride;
@@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
 
     q->dptr[q->push_idx] = dptr;
 
-    dmlink(q->tail, desc);
-    q->tail = desc;
+    if (nrows) {
+        dmlink(q->tail, desc);
+        q->tail = desc;
+    } else {
+        desc->done = 1;
+    }
 
     // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
     q->push_idx = (q->push_idx + 1) & q->idx_mask;
@@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
     dma_descriptor_2d * desc = &q->desc[q->pop_idx];
 
     // Wait for desc to complete
-    while (1) {
-        dmpoll();
-        if (desc->done) {
-            break;
-        }
+    while (!desc->done) {
         // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
+        dmpoll();
     }
 
     dptr = q->dptr[q->pop_idx];
@@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
     return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
 }
 
+#define DMA_CACHE_MAX_SIZE 64U
+
+typedef struct {
+    uint8_t *base;
+    uint32_t line_size;
+    uint32_t capacity;
+    uint32_t src[DMA_CACHE_MAX_SIZE];
+    uint16_t age[DMA_CACHE_MAX_SIZE];
+} dma_cache;
+
+static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity)
+{
+    c->capacity  = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity;
+    c->base      = base;
+    c->line_size = line_size;
+
+    for (unsigned i=0; i < c->capacity; i++) {
+        c->src[i] = 0;
+        c->age[i] = 0;
+    }
+}
+
+static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows)
+{
+    uint32_t o_idx = 0;
+    uint16_t o_age = 0;
+    uint8_t *  dst = 0;
+
+    for (unsigned i=0; i < c->capacity; i++) {
+        if (c->src[i] == (uint32_t) src) {
+            c->age[i] = 0;
+            dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
+            // FARF(ERROR, "dma-cache: found %p", src);
+        } else {
+            c->age[i]++;
+            if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
+        }
+    }
+    if (!dst) {
+        // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
+        c->age[o_idx] = 0;
+        c->src[o_idx] = (uint32_t) src;
+        dst = c->base + o_idx * c->line_size; // normal nrows dma
+    }
+
+    return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
+}
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
index be9469538f63a1ca2ef8302b2cc7b4c15d4638c9..ecedadb0fead95560f55fbd7110f1e20808e3751 100644 (file)
@@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
                     //         (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
                 }
 
-                // Skip DMA transactions from prev block (if any)
-                // No need to wait for these since the DMA is setup for in-order processing
+                // Skip output DMA transactions from prev block (if any)
+                // No need to wait for those here since we're explicitly waiting for the latest prefecthes below.
                 for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
 
                 // Compute loop