const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);
+ dma_cache m_cache;
+ dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);
+
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
- uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
- dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
+ dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
- dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
+ dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}
// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
- octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
+ octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
+ // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);
+
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
- desc->order = 1;
+ desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
q->dptr[q->push_idx] = dptr;
- dmlink(q->tail, desc);
- q->tail = (dma_descriptor_2d *) desc;
+ if (size) {
+ dmlink(q->tail, desc);
+ q->tail = (dma_descriptor_2d *) desc;
+ } else {
+ desc->done = 1;
+ }
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->src_comp = 0;
desc->dst_comp = 0;
- desc->order = 1;
+ desc->order = 0;
desc->done = 0;
desc->src_stride = src_stride;
desc->dst_stride = dst_stride;
q->dptr[q->push_idx] = dptr;
- dmlink(q->tail, desc);
- q->tail = desc;
+ if (nrows) {
+ dmlink(q->tail, desc);
+ q->tail = desc;
+ } else {
+ desc->done = 1;
+ }
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
dma_descriptor_2d * desc = &q->desc[q->pop_idx];
// Wait for desc to complete
- while (1) {
- dmpoll();
- if (desc->done) {
- break;
- }
+ while (!desc->done) {
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
+ dmpoll();
}
dptr = q->dptr[q->pop_idx];
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}
+#define DMA_CACHE_MAX_SIZE 64U // upper bound on the number of cache lines; also sizes src[] and age[] below
+ // Bookkeeping for a small cache of DMA destination lines, set up by dma_cache_init() and used by dma_cache_push()
+typedef struct {
+ uint8_t *base; // start of the buffer that holds the lines; line i is at base + i * line_size
+ uint32_t line_size; // distance in bytes between the starts of consecutive lines
+ uint32_t capacity; // number of lines in use; dma_cache_init() clamps it to DMA_CACHE_MAX_SIZE
+ uint32_t src[DMA_CACHE_MAX_SIZE]; // per line: source address currently cached there, kept as a 32-bit value (0 after init)
+ uint16_t age[DMA_CACHE_MAX_SIZE]; // per line: pushes since it was last hit or refilled; the largest age is replaced first
+} dma_cache;
+
+// Set up cache `c` over the buffer starting at `base`, split into lines of
+// `line_size` bytes. `capacity` is the requested number of lines; values above
+// DMA_CACHE_MAX_SIZE are clamped. Every line that will be used starts with
+// source address 0 and age 0.
+static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity)
+{
+    uint32_t n_lines = capacity;
+    if (n_lines > DMA_CACHE_MAX_SIZE) {
+        n_lines = DMA_CACHE_MAX_SIZE;
+    }
+
+    c->base      = base;
+    c->line_size = line_size;
+    c->capacity  = n_lines;
+
+    // Only the slots below the clamped capacity are reset.
+    uint32_t slot = 0;
+    while (slot < n_lines) {
+        c->age[slot] = 0;
+        c->src[slot] = 0;
+        ++slot;
+    }
+}
+
+// Queue a DMA of `src` into cache `c`, skipping the copy when `src` is already cached.
+//
+// Every slot is compared against the address of `src`:
+//   - hit:  the matching line is reused, its age is reset, and the request is
+//           forwarded with nrows == 0 (a placeholder entry; the queue is expected
+//           to complete it without starting a transfer).
+//   - miss: the slot with the largest age is overwritten with `src` and the
+//           transfer is forwarded with the caller's nrows.
+// All non-matching slots age by one per call. Exactly one dma_queue_push() call is
+// made either way, and its result is returned.
+//
+// Note: the lookup key is the source address only. row_size / nrows are not part of
+// the key, and the transfer size is not checked against c->line_size.
+static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows)
+{
+    // Convert through uintptr_t so the pointer-to-integer step is well defined;
+    // the cache stores addresses as 32-bit values.
+    const uint32_t key = (uint32_t) (uintptr_t) src;
+
+    bool     hit        = false;
+    uint32_t hit_idx    = 0;
+    uint32_t victim     = 0;  // first slot holding the largest age seen so far
+    uint16_t victim_age = 0;
+
+    for (uint32_t i = 0; i < c->capacity; i++) {
+        if (c->src[i] == key) {
+            c->age[i] = 0;
+            hit       = true;
+            hit_idx   = i;
+            // FARF(ERROR, "dma-cache: found %p", src);
+            continue;
+        }
+
+        // Saturate instead of wrapping: a slot that is never replaced would
+        // otherwise roll over to age 0 and look like the freshest line.
+        if (c->age[i] < UINT16_MAX) {
+            c->age[i]++;
+        }
+        if (c->age[i] > victim_age) {
+            victim_age = c->age[i];
+            victim     = i;
+        }
+    }
+
+    uint32_t line;
+    if (hit) {
+        line  = hit_idx;
+        nrows = 0; // dummy dma: data is already in the line
+    } else {
+        // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", victim, c->age[victim], (void *) c->src[victim], src);
+        c->age[victim] = 0;
+        c->src[victim] = key;
+        line           = victim; // normal nrows dma
+    }
+
+    uint8_t * dst = c->base + line * c->line_size;
+
+    return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
+}
+
#ifdef __cplusplus
} // extern "C"
#endif