id<MTLComputePipelineState> pipeline;
};
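+
+// Obj-C wrapper around ggml_metal_kernel, needed to store the pipelines in an NSMutableDictionary;
+// the pipeline is released when the wrapper is deallocated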
+@interface ggml_metal_kernel_wrapper : NSObject
+
+@property (nonatomic, assign) struct ggml_metal_kernel kernel;
+
+@end
+
+@implementation ggml_metal_kernel_wrapper
+- (void) dealloc {
+ [_kernel.pipeline release];
+ [super dealloc];
+}
+@end
+
enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_ADD,
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
- GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
GGML_METAL_KERNEL_TYPE_SET_I32,
GGML_METAL_KERNEL_TYPE_SET_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
dispatch_queue_t d_queue;
+ // the set of pre-compiled kernels for this context
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
+ // additional, inference-time compiled kernels
+ NSMutableDictionary * kernels_ext;
+
// capture state
bool capture_next_compute;
bool capture_started;
// - if not found, load the source and compile it
// - if that fails, return NULL
static id<MTLLibrary> ggml_metal_load_library(id<MTLDevice> device, bool use_bfloat) {
+ const int64_t t_start = ggml_time_us();
+
id<MTLLibrary> metal_library = nil;
NSError * error = nil;
NSString * src = nil;
[src release];
#endif // GGML_METAL_EMBED_LIBRARY
+ GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
+
return metal_library;
}
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40, flash_attn_ext_f16_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40, flash_attn_ext_bf16_h40, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40, flash_attn_ext_q4_0_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40, flash_attn_ext_q4_1_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40, flash_attn_ext_q5_0_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40, flash_attn_ext_q5_1_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40, flash_attn_ext_q8_0_h40, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192, flash_attn_ext_vec_f16_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192, flash_attn_ext_vec_bf16_h192, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192, flash_attn_ext_vec_q4_0_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192, flash_attn_ext_vec_q4_1_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192, flash_attn_ext_vec_q5_0_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192, flash_attn_ext_vec_q5_1_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192, flash_attn_ext_vec_q8_0_h192, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128, flash_attn_ext_vec_f16_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128, flash_attn_ext_vec_bf16_hk192_hv128, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128, flash_attn_ext_vec_q4_0_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128, flash_attn_ext_vec_q4_1_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128, flash_attn_ext_vec_q5_0_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128, flash_attn_ext_vec_q5_1_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128, flash_attn_ext_vec_q8_0_hk192_hv128, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}
+ ctx->kernels_ext = [[NSMutableDictionary alloc] init];
+
return ctx;
}
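+
+// look up an inference-time compiled kernel by its unique name
+// returns nil if it has not been compiled yet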
+static id<MTLComputePipelineState> ggml_metal_get_kernel(struct ggml_backend_metal_context * ctx, const char * name) {
+ NSString * key = [NSString stringWithUTF8String:name];
+
+ ggml_metal_kernel_wrapper * obj = [ctx->kernels_ext objectForKey:key];
+ if (obj) {
+ return obj.kernel.pipeline;
+ }
+
+ return nil;
+}
+
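+// compile a specialization of the "base" Metal function with the given function-constant values
+// and cache the resulting pipeline in ctx->kernels_ext under the unique "name"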
+static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t backend, const char * base, const char * name, MTLFunctionConstantValues * cv) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+ struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+ id<MTLComputePipelineState> res = nil;
+
+ @autoreleasepool {
+ NSError * error = nil;
+
+ NSString * base_func = [NSString stringWithUTF8String:base];
+
+ GGML_LOG_DEBUG("%s: compiling kernel: base = '%s', name = '%s'\n", __func__, base, name);
+
+ // TODO: make sure it is thread-safe to compile kernels in parallel
+ id<MTLFunction> metal_function = [ctx_dev->mtl_library newFunctionWithName:base_func constantValues:cv error:&error];
+ if (!metal_function) {
+ GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+
+ return nil;
+ }
+
+ struct ggml_metal_kernel kernel = {
+ /*.pipeline =*/ [ctx_dev->mtl_device newComputePipelineStateWithFunction:metal_function error:&error],
+ };
+
+ [metal_function release];
+
+ if (!kernel.pipeline) {
+ GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+ return nil;
+ }
+
+ ggml_metal_kernel_wrapper * obj = [[ggml_metal_kernel_wrapper alloc] init];
+ obj.kernel = kernel;
+
+ res = obj.kernel.pipeline;
+
+ NSString * key = [NSString stringWithUTF8String:name];
+ [ctx->kernels_ext setObject:obj forKey:key];
+
+ // the dictionary retains the wrapper
+ [obj release];
+
+ GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline,
+ (int) kernel.pipeline.maxTotalThreadsPerThreadgroup,
+ (int) kernel.pipeline.threadExecutionWidth);
+ }
+
+ return res;
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ bool has_mask,
+ bool has_sinks,
+ bool has_bias,
+ bool has_scap,
+ int32_t nsg) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
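+ // "base" is the name of the Metal function to specialize, while "name" also encodes the
+ // function-constant values and is used as the cache key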
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+
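+ // dk/dv are the head sizes of K and V; ns10/ns20 are the row strides of K and V in elements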
+ const int32_t dk = (int32_t) op->src[1]->ne[0];
+ const int32_t dv = (int32_t) op->src[2]->ne[0];
+
+ const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
+ const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
+
+ snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
+ "flash_attn_ext",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv);
+
+ snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+ "flash_attn_ext",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv,
+ has_mask,
+ has_sinks,
+ has_bias,
+ has_scap,
+ ns10,
+ ns20,
+ nsg);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
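+ // the indices below are offsets from FC_FLASH_ATTN_EXT and have to match the
+ // corresponding function constants in the Metal source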
+ [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0];
+ [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1];
+ [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 2];
+ [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 3];
+
+ [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 20];
+ [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21];
+ [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22];
+
+ res = ggml_metal_compile_kernel(backend, base, name, cv);
+
+ [cv release];
+
+ return res;
+ }
+}
+
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ bool has_mask,
+ bool has_sinks,
+ bool has_bias,
+ bool has_scap,
+ int32_t nsg,
+ int32_t nwg) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+
+ const int32_t dk = (int32_t) op->src[1]->ne[0];
+ const int32_t dv = (int32_t) op->src[2]->ne[0];
+
+ const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
+ const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
+
+ snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
+ "flash_attn_ext_vec",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv);
+
+ snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+ "flash_attn_ext_vec",
+ ggml_type_name(op->src[1]->type),
+ dk,
+ dv,
+ has_mask,
+ has_sinks,
+ has_bias,
+ has_scap,
+ ns10,
+ ns20,
+ nsg, nwg);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+ [cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0];
+ [cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1];
+ [cv setConstantValue:&has_bias type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 2];
+ [cv setConstantValue:&has_scap type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 3];
+
+ [cv setConstantValue:&ns10 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 20];
+ [cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 21];
+ [cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
+ [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23];
+
+ res = ggml_metal_compile_kernel(backend, base, name, cv);
+
+ [cv release];
+
+ return res;
+ }
+}
+
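+// pipeline for the pass that combines the partial outputs produced when the vec kernel
+// runs with nwg > 1 workgroups per row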
+static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(
+ ggml_backend_t backend, struct ggml_tensor * op,
+ int32_t dv,
+ int32_t nwg) {
+ struct ggml_backend_metal_context * ctx = backend->context;
+
+ char base[256];
+ char name[256];
+
+ @autoreleasepool {
+
+ snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
+ snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
+
+ id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
+ if (res) {
+ // kernel found
+ return res;
+ }
+
+ MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
+
+ [cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0];
+ [cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1];
+
+ res = ggml_metal_compile_kernel(backend, base, name, cv);
+
+ [cv release];
+
+ return res;
+ }
+
+ GGML_UNUSED(op);
+}
+
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
GGML_LOG_INFO("%s: deallocating\n", __func__);
[ctx->kernels[i].pipeline release];
}
+ if (ctx->kernels_ext) {
+ [ctx->kernels_ext release];
+ ctx->kernels_ext = nil;
+ }
+
Block_release(ctx->encode_async);
[ctx->queue release];
{
nsg = N_SG_Q8_0;
nr0 = N_R0_Q8_0;
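+ // threadgroup scratch: 32 floats (one per thread in the simdgroup) for each of the N_R0_Q8_0 rows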
+ smem = 32*sizeof(float)*N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
if (smem > 0) {
[encoder setThreadgroupMemoryLength:smem atIndex:0];
}
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
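+ // the updated q8_0 kernel processes nr0 rows per threadgroup, independent of the number of simdgroups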
+ if (src0t == GGML_TYPE_Q8_0) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
}
} break;
case GGML_OP_MUL_MAT_ID:
{
nsg = N_SG_Q8_0;
nr0 = N_R0_Q8_0;
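+ // threadgroup scratch for the per-row partial sums, same as in the MUL_MAT path above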
+ smem = 32*sizeof(float)*N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
case GGML_TYPE_MXFP4:
if (smem > 0) {
[encoder setThreadgroupMemoryLength:smem atIndex:0];
}
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+
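+ // same q8_0-specific dispatch as in the MUL_MAT path above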
+ if (src0t == GGML_TYPE_Q8_0) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ }
}
} break;
case GGML_OP_GET_ROWS:
float scale;
float max_bias;
float logit_softcap;
+
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
scale /= logit_softcap;
}
+ const bool has_mask = src3 != NULL;
+ const bool has_sinks = src4 != NULL;
+ const bool has_bias = max_bias != 0.0f;
+ const bool has_scap = logit_softcap != 0.0f;
+
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- id<MTLComputePipelineState> pipeline = nil;
-
- bool use_vec_kernel = false;
+ GGML_ASSERT(ne01 < 65536);
// use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
- if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
- switch (src1->type) {
- case GGML_TYPE_F16:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_BF16:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q4_0:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q4_1:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q5_0:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q5_1:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- case GGML_TYPE_Q8_0:
- {
- if (ne00 == 192 && ne20 == 128) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
- } else if (ne00 == 576 && ne20 == 512) {
- pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
- } else {
- switch (ne00) {
- case 40: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H40 ].pipeline; break;
- case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
- case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break;
- case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break;
- case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break;
- case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break;
- case 192: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192].pipeline; break;
- case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
- } break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- use_vec_kernel = true;
-
- switch (ne00) {
- case 64:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 96:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 128:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 192:
- {
- if (ne20 == 128) {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK192_HV128].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK192_HV128].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK192_HV128].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H192].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H192].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H192].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H192].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H192].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H192].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H192].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- }
- } break;
- case 256:
- {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } break;
- case 576:
- {
- if (ne20 == 512) {
- switch (src1->type) {
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
- case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
- case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
- default:
- {
- GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
- GGML_LOG_ERROR("add template specialization for this type\n");
- GGML_ABORT("add template specialization for this type");
- }
- }
- } else {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- } break;
- default:
- {
- GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
- GGML_LOG_ERROR("add template specialization for this size\n");
- GGML_ABORT("add template specialization for this size");
- }
- }
- }
-
- ggml_metal_kargs_flash_attn_ext args = {
- /*.ne01 =*/ ne01,
- /*.ne02 =*/ ne02,
- /*.ne03 =*/ ne03,
- /*.nb01 =*/ nb01,
- /*.nb02 =*/ nb02,
- /*.nb03 =*/ nb03,
- /*.ne11 =*/ ne11,
- /*.ne_12_2 =*/ ne12,
- /*.ne_12_3 =*/ ne13,
- /*.nb11 =*/ nb11,
- /*.nb12 =*/ nb12,
- /*.nb13 =*/ nb13,
- /*.nb21 =*/ nb21,
- /*.nb22 =*/ nb22,
- /*.nb23 =*/ nb23,
- /*.ne32 =*/ ne32,
- /*.ne33 =*/ ne33,
- /*.nb31 =*/ nb31,
- /*.nb32 =*/ nb32,
- /*.nb33 =*/ nb33,
- /*.ne1 =*/ ne1,
- /*.ne2 =*/ ne2,
- /*.ne3 =*/ ne3,
- /*.scale =*/ scale,
- /*.max_bias =*/ max_bias,
- /*.m0 =*/ m0,
- /*.m1 =*/ m1,
- /*.n_head_log2 =*/ n_head_log2,
- /*.logit_softcap =*/ logit_softcap,
- };
-
- [encoder setComputePipelineState:pipeline];
- [encoder setBytes:&args length:sizeof(args) atIndex:0];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- if (id_src3) {
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
- }
- if (id_src4) {
- [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
- } else {
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
- }
-
- if (!use_vec_kernel) {
+ if (ne01 >= 20 || (ne00 % 32 != 0)) {
// half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
- const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+ const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 8 == 0);
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
- // 2*(2*ncpsg + nqptg)*(nsg)
- // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+ // 2*(2*ncpsg)
+ // ncpsg soft_max values + ncpsg mask values
//
// 16*32*(nsg)
// the shared memory needed for the simdgroups to load the KV cache
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
//
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
-
- int64_t nsgmax = 2;
- while (true) {
- const size_t smem = FATTN_SMEM(nsgmax);
- if (smem > device.maxThreadgroupMemoryLength/2) {
- break;
- }
- nsgmax *= 2;
- }
- nsgmax /= 2;
+ //int64_t nsgmax = 4;
+ //
+ //if (is_q) {
+ // nsgmax = 2;
+ // while (true) {
+ // const size_t smem = FATTN_SMEM(nsgmax);
+ // if (smem > device.maxThreadgroupMemoryLength/2) {
+ // break;
+ // }
+ // nsgmax *= 2;
+ // }
+ // nsgmax /= 2;
+ //}
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+ //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
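+ // nsg is baked into the pipeline as a function constant, so a fixed value is used here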
+ int32_t nsg = 4;
const size_t smem = FATTN_SMEM(nsg);
+ ggml_metal_kargs_flash_attn_ext args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ nb11/nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ nb21/nb20,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
+
+ id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_flash_attn_ext(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg);
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+ }
+ if (id_src4) {
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+ }
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
- //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
+ //printf("smem: %zu, max: %zu, nsg = %d, ne02 = %d, ne12 = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, ne02, ne12);
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// half4x4 kernel
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
- const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
+ const int64_t nkpsg = 1*ncpsg;
GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
-//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))
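+// worked example with hypothetical sizes (ne00 = ne20 = 128, nqptg = 1, ncpsg = 32, nsg = 2):
+//   (128 + 4*32*2) + 2*128*2 = 896 half-sized elements -> 896*2 = 1792 bytes of threadgroup memory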
int64_t nsgmax = 2;
while (true) {
nsgmax /= 2;
// simdgroups per threadgroup (a.k.a. warps)
- const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
+ const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
int64_t nsg = 1;
while (nsg <= nsgt) {
// workgroups
// each workgroup handles nsg*nkpsg cache values
- uint16_t nwg = 1;
- if (4*nsg*nkpsg >= ne11) {
- const size_t smem = FATTN_SMEM(nsg);
+ int32_t nwg = 1;
+ if (false) {
+ // for small KV caches, we could launch a single workgroup and write the results directly to dst.
+ // however, this does not lead to a significant improvement, so it is disabled for now
+ nwg = 1;
+ nsg = 4;
+ } else {
+ nwg = 32;
+ nsg = 1;
+ while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
+ nsg *= 2;
+ }
+ }
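+ // worked example (hypothetical): ne11 = 4096, nkpsg = 32 -> nwg = 32, nsg grows 1 -> 2 and stops
+ // (2*32*2*32 = 4096 is not < ne11), so each workgroup covers nsg*nkpsg = 64 cache values per pass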
- //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+ ggml_metal_kargs_flash_attn_ext_vec args = {
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne11 =*/ ne11,
+ /*.ne_12_2 =*/ ne12,
+ /*.ne_12_3 =*/ ne13,
+ /*.ns10 =*/ nb11/nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.ns20 =*/ nb21/nb20,
+ /*.nb21 =*/ nb21,
+ /*.nb22 =*/ nb22,
+ /*.nb23 =*/ nb23,
+ /*.ne32 =*/ ne32,
+ /*.ne33 =*/ ne33,
+ /*.nb31 =*/ nb31,
+ /*.nb32 =*/ nb32,
+ /*.nb33 =*/ nb33,
+ /*.ne1 =*/ ne1,
+ /*.ne2 =*/ ne2,
+ /*.ne3 =*/ ne3,
+ /*.scale =*/ scale,
+ /*.max_bias =*/ max_bias,
+ /*.m0 =*/ m0,
+ /*.m1 =*/ m1,
+ /*.n_head_log2 =*/ n_head_log2,
+ /*.logit_softcap =*/ logit_softcap,
+ };
- // using 1 workgroup -> write the result directly into dst
- [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
- [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+ id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_flash_attn_ext_vec(backend, node, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
- [encoder setThreadgroupMemoryLength:smem atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ GGML_ASSERT(nsg*32 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ if (id_src3) {
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
} else {
- nwg = 32;
- nsg = MIN(4, nsg);
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
+ }
+ if (id_src4) {
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+ }
- const size_t smem = FATTN_SMEM(nsg);
+ const size_t smem = FATTN_SMEM(nsg);
- //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
- GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+ //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
+ GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
+
+ if (nwg == 1) {
+ // using 1 workgroup -> write the result directly into dst
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+ [encoder setThreadgroupMemoryLength:smem atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+ } else {
// sanity checks
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
- [encoder setBuffer:h_tmp offset:0 atIndex:6];
- [encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
+ [encoder setBuffer:h_tmp offset:0 atIndex:6];
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// reduce the results from the workgroups
{
- ggml_metal_kargs_flash_attn_ext_reduce args0 = {
+ ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
nrows,
- ne20,
};
- id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
+ id<MTLComputePipelineState> pipeline0 = ggml_metal_get_pipeline_flash_attn_ext_vec_reduce(backend, node, ne20, nwg);
[encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
- [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*nwg, 1, 1)];
}
}
#undef FATTN_SMEM
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
+#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
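+// rounds x up to the next multiple of n (n must be a power of two), e.g. PAD2(40, 64) == 64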
+
+#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
+
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
}
-template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
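+// reduce the per-thread partial sums for NR0 rows across all simdgroups in the threadgroup and
+// write the final NR0 results to dst_f32 from thread (tiisg == 0, sgitg == 0):
+//   1) simd_sum reduces each row within the simdgroup
+//   2) lane 0 of each simdgroup stores its value into threadgroup memory (one float per simdgroup)
+//   3) a final simd_sum over the threadgroup buffer yields the total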
+template<short NR0, short NW>
+static inline void helper_mv_reduce_and_write(
+ device float * dst_f32,
+ float sumf[NR0],
+ const int r0,
+ const int ne01,
+ ushort tiisg,
+ ushort sgitg,
+ threadgroup char * shmem) {
+ threadgroup float * shmem_f32[NR0];
+
+ for (short row = 0; row < NR0; ++row) {
+ shmem_f32[row] = (threadgroup float *) shmem + NW*row;
+
+ if (sgitg == 0) {
+ shmem_f32[row][tiisg] = 0.0f;
+ }
+
+ sumf[row] = simd_sum(sumf[row]);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short row = 0; row < NR0; ++row) {
+ if (tiisg == 0) {
+ shmem_f32[row][sgitg] = sumf[row];
+ }
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {
+ float tot = simd_sum(shmem_f32[row][tiisg]);
+
+ if (tiisg == 0 && sgitg == 0) {
+ dst_f32[r0 + row] = tot;
+ }
+ }
+}
+
+template<typename block_q_type, short NR0, short NSG, short NW, typename args_t>
void mul_vec_q_n_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
- const int nb = args.ne00/QK4_0;
+ constexpr short NQ = 16;
- const int r0 = tgpig.x;
- const int r1 = tgpig.y;
- const int im = tgpig.z;
+ const int nb = args.ne00/QK4_0;
- const int first_row = (r0 * nsg + sgitg) * nr0;
+ const int r0 = (tgpig.x*NSG + sgitg)*NR0;
+ //const int r0 = tgpig.x*NR0;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
- //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
- device const block_q_type * ax[nr0];
- for (int row = 0; row < nr0; ++row) {
- const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ device const block_q_type * ax[NR0];
+ FOR_UNROLL (int row = 0; row < NR0; ++row) {
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
- float yl[16]; // src1 vector cache
- float sumf[nr0] = {0.f};
+ float sumf[NR0] = {0.f};
- const short ix = (tiisg/2);
- const short il = (tiisg%2)*8;
+ const short ix = (tiisg/(NW/NQ));
+ const short il = (tiisg%(NW/NQ))*8;
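+ // with NW = 32 and NQ = 16: two threads per block - ix selects the block, il selects the half (0 or 8)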
- device const float * yb = y + ix*QK4_0 + il;
+ //const int ib0 = sgitg*NQ + ix;
+ const int ib0 = ix;
+
+ float yl[16]; // src1 vector cache
+
+ //device const float * yb = y + ix*QK4_0 + il;
+ device const float * yb = y + ib0*QK4_0 + il;
// each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += nw/2) {
+ //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
+ for (int ib = ib0; ib < nb; ib += NQ) {
float sumy[2] = { 0.f, 0.f };
-#pragma unroll
- for (short i = 0; i < 8; i += 2) {
+ FOR_UNROLL (short i = 0; i < 8; i += 2) {
sumy[0] += yb[i + 0] + yb[i + 1];
yl[i + 0] = yb[i + 0];
yl[i + 1] = yb[i + 1]/256.f;
yl[i + 9] = yb[i + 17]/4096.f;
}
-#pragma unroll
- for (short row = 0; row < nr0; row++) {
+ FOR_UNROLL (short row = 0; row < NR0; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
}
yb += QK4_0 * 16;
+ //yb += NSG*NQ*QK4_0;
}
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
- for (int row = 0; row < nr0; ++row) {
+ //helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+
+ for (int row = 0; row < NR0; ++row) {
const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && first_row + row < args.ne01) {
- dst_f32[first_row + row] = tot;
+ if (tiisg == 0 && r0 + row < args.ne01) {
+ dst_f32[r0 + row] = tot;
}
}
}
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
-#define NB_Q8_0 8
-
-template<int nr0, int nsg, int nw, typename args_t>
+template<short NR0, short NSG, short NW, typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
args_t args,
device const char * src0,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
+ constexpr short NQ = 8;
+
const int nb = args.ne00/QK8_0;
- const int r0 = tgpig.x;
+ const int r0 = tgpig.x*NR0;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr0;
-
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
- //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
- const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+ //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
- device const block_q8_0 * ax[nr0];
- for (int row = 0; row < nr0; ++row) {
- const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ device const block_q8_0 * ax[NR0];
+ FOR_UNROLL (short row = 0; row < NR0; ++row) {
+ const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
- float yl[NB_Q8_0];
- float sumf[nr0] = { 0.f };
+ float sumf[NR0] = { 0.f };
+
+ const short ix = tiisg/(NW/NQ);
+ const short il = tiisg%(NW/NQ);
+
+ const int ib0 = sgitg*NQ + ix;
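+ // with NW = 32 and NQ = 8: four threads per block, each covering NQ = 8 quants;
+ // consecutive simdgroups start NQ blocks apart and advance by NSG*NQ blocks per iteration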
- const short ix = tiisg/4;
- const short il = tiisg%4;
+ float yl[NQ];
- device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
+ device const float * yb = y + ib0*QK8_0 + il*NQ;
- // each thread in a SIMD group deals with NB_Q8_0 quants at a time
- for (int ib = ix; ib < nb; ib += nw/4) {
- for (short i = 0; i < NB_Q8_0; ++i) {
+ // each thread in a SIMD group deals with NQ quants at a time
+ for (int ib = ib0; ib < nb; ib += NSG*NQ) {
+ for (short i = 0; i < NQ; ++i) {
yl[i] = yb[i];
}
- for (short row = 0; row < nr0; row++) {
- device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
+ for (short row = 0; row < NR0; row++) {
+ device const int8_t * qs = ax[row][ib].qs + il*NQ;
+
float sumq = 0.f;
- for (short iq = 0; iq < NB_Q8_0; ++iq) {
- sumq += qs[iq] * yl[iq];
+ FOR_UNROLL (short i = 0; i < NQ; ++i) {
+ sumq += qs[i] * yl[i];
}
+
sumf[row] += sumq*ax[row][ib].d;
}
- yb += nw*NB_Q8_0;
+ yb += NSG*NQ*QK8_0;
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
- for (int row = 0; row < nr0; ++row) {
- const float tot = simd_sum(sumf[row]);
-
- if (tiisg == 0 && first_row + row < args.ne01) {
- dst_f32[first_row + row] = tot;
- }
- }
+ helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
}
[[host_name("kernel_mul_mv_q8_0_f32")]]
device const char * src0,
device const char * src1,
device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
// mat-vec kernel processing in chunks of float4
dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
}
+constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
+constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
+constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
+constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
+
+//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
+//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
+//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
+
+constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
+constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
+constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
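+// note: these function constants are provided by the host when the pipeline state is created
+// (e.g. via MTLFunctionConstantValues), which lets the compiler specialize each kernel variant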
+
// ref: https://arxiv.org/pdf/2307.08691.pdf
template<
typename q_t, // query types in shared memory
typename qk_t, // Q*K types
typename qk8x8_t,
typename s_t, // soft-max types
+ typename s2_t,
typename s8x8_t,
typename o_t, // attention accumulation types
typename o4_t,
typename vd4x4_t, // value type in device memory
short nl_v,
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
- short DK, // K head size
- short DV, // V head size
- short Q = 8, // queries per threadgroup
- short KV = 8, // key/value processed per each simdgroup
- short C = 32> // cache items per threadgroup
-kernel void kernel_flash_attn_ext(
+ short DK, // K head size
+ short DV, // V head size
+ short Q, // queries per threadgroup
+ short C, // cache items per threadgroup
+ short NSG> // number of simd groups
+void kernel_flash_attn_ext_impl(
constant ggml_metal_kargs_flash_attn_ext & args,
device const char * q,
device const char * k,
device const char * mask,
device const char * sinks,
device char * dst,
- threadgroup half * shmem_f16 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- const short nsg = ntg.y; // number of simdgroups
-
- const int iq3 = tgpig[2];
- const int iq2 = tgpig[1];
- const int iq1 = tgpig[0]*Q;
+ threadgroup half * shmem_f16,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const ushort iq3 = tgpig[2];
+ const ushort iq2 = tgpig[1];
+ const ushort iq1 = tgpig[0]*Q;
+
+#define NS10 (FC_flash_attn_ext_ns10)
+#define NS20 (FC_flash_attn_ext_ns20)
+
+ // note: I had some concerns that using this instead of the ugly macros above was affecting performance.
+ // need to re-check carefully and, if no regressions are observed, remove the macros.
+ // the concern is that using const variables might require extra registers, but I am not sure
+ // if the compiler is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
+ //const short NS10 = FC_flash_attn_ext_ns10;
+ //const short NS20 = FC_flash_attn_ext_ns20;
+
+ constexpr short KV = 8;
constexpr short DK4 = DK/4;
constexpr short DK8 = DK/8;
constexpr short DK16 = DK/16;
constexpr short DV4 = DV/4;
- constexpr short DV8 = DV/8;
+ //constexpr short DV8 = DV/8;
constexpr short DV16 = DV/16;
+ constexpr short PV = PAD2(DV, 64);
+ constexpr short PV4 = PV/4;
+ constexpr short PV8 = PV/8;
+ //constexpr short PV16 = PV/16;
+
constexpr short NW = N_SIMDWIDTH;
- constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
+ constexpr short NQ = Q/NSG;
+ constexpr short SH = 2*C; // shared memory per query in (s_t == float)
+
+ constexpr short TS = 2*SH;
+ constexpr short T = DK + 2*PV; // shared memory size per query in (half)
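+
+ // resulting threadgroup memory layout in half-sized units (sketch, assuming the defaults Q = 8, C = 64):
+ //
+ //   [0, Q*DK)          : query data (sq)
+ //   [Q*DK, Q*T)        : output accumulator (so), PV o_t values per query
+ //   [Q*T, Q*T + Q*TS)  : per-query scratch - C attention scores (s_t) followed by C mask values
+ //   [Q*T + Q*TS, ...)  : per-simdgroup staging buffers for loading K/V (4*16*KV halves each)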
+
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix
+ threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t
+
+ threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
- const short T = 2*DK + 2*TS; // shared memory size per query in (half)
+ threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
+ // mask storage in shared mem
+ threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
- threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
- threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+ // per-query mask pointers
+ device const half2 * pm2[NQ];
- threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
- threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
- o8x8_t lo[DV8];
+ pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
+ }
+
+ {
+ q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
+
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+ k += ikv2*args.nb12 + ikv3*args.nb13;
+ v += ikv2*args.nb22 + ikv3*args.nb23;
+ }
// load heads from Q to shared memory
- for (short j = sgitg; j < Q; j += nsg) {
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);
for (short i = tiisg; i < DK4; i += NW) {
if (iq1 + j < args.ne01) {
}
}
- // zero out lo
- for (short i = 0; i < DV8; ++i) {
- lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
- }
+ // zero out the output accumulator and the attention scores scratch
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
+ for (short i = tiisg; i < DV4; i += NW) {
+ so4[j*PV4 + i] = 0;
+ }
- // zero out shared memory SH
- for (short j = 0; j < Q; ++j) {
for (short i = tiisg; i < SH; i += NW) {
- ss[j*TS + i] = 0.0f;
+ ss[j*SH + i] = 0.0f;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
- {
- float S[Q] = { [0 ... Q-1] = 0.0f };
- float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
-
- // thread indices inside the simdgroup
- // TODO: see if we can utilize quad-group functions for better performance
- // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
- const short tx = tiisg%4;
- const short ty = tiisg/4;
-
- // broadcast kv
- //const short rk2 = args.ne02/args.ne12;
- //const short rk3 = args.ne03/args.ne13;
-
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+ float S[NQ] = { [0 ... NQ-1] = 0.0f };
- const bool has_mask = mask != q;
+ {
+ float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };
float slope = 1.0f;
// ALiBi
- if (args.max_bias > 0.0f) {
+ if (FC_flash_attn_ext_has_bias) {
const short h = iq2;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
- const int ic = ic0 + C*sgitg;
- if (ic >= args.ne11) {
- break;
- }
-
- if (has_mask) {
- // used to detect blocks full of -INF
- float smax = -INFINITY;
+ for (int ic = 0; ic < args.ne11; ic += C) {
+ // read the mask into shared mem
+ if (FC_flash_attn_ext_has_mask) {
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+
+ sm2[j*SH + tiisg] = pm2[jj][tiisg];
+ pm2[jj] += NW;
+ }
- // load the mask in shared memory
- #pragma unroll(Q)
- for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- const float m = pm[ic + tiisg];
+ // used to detect blocks full of -INF
+ // skip only when the entire threadgroup is masked
+ half2 smax2(-MAXHALF/2, -MAXHALF/2);
- ss[j*TS + C + tiisg] = m;
- smax = max(smax, m);
+ FOR_UNROLL (short j = 0; j < Q; ++j) {
+ smax2 = max(smax2, sm2[j*SH + tiisg]);
}
- smax = simd_max(smax);
+ smax2 = simd_max(smax2);
+
+ if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
+ // this barrier is important: every simdgroup reads all Q rows of sm2 above, so none of them
+ // may proceed to overwrite sm2 for the next block while the others are still reading it
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- if (smax == -INFINITY) {
continue;
}
}
// Q*K^T
- {
- for (short cc = 0; cc < C/8; ++cc) {
+ // this is a compile-time check, so it does not have runtime overhead
+ if (is_same<kd4x4_t, k4x4_t>::value) {
+ // we can read directly from global memory
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
+ threadgroup const q_t * pq = sq;
+ threadgroup s_t * ps = ss;
+
+ pk += sgitg*(8*NS10);
+ ps += sgitg*(8*1);
+
+ static_assert((C/8) % NSG == 0, "");
+
+ constexpr short NC = (C/8)/NSG;
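+ // e.g. with C = 64 and NSG = 4, each simdgroup computes NC = 2 of the C/8 = 8 8x8 QK blocks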
+
+ // TODO: not good to unroll for large contexts - not sure why?
+ for (short cc = 0; cc < NC; ++cc) {
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
- // this is compile-time check, so it does not have runtime overhead
- if (is_same<kd4x4_t, k4x4_t>::value) {
- // we can read directly from global memory
- device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+ if (DK8 % 16 != 0) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ FOR_UNROLL (short i = 0; i < DK8; ++i) {
+ simdgroup_barrier(mem_flags::mem_none);
+
+ simdgroup_load(mk, pk, NS10, 0, true);
+ simdgroup_load(mq, pq, DK);
- #pragma unroll(DK8)
- for (short i = 0; i < DK8; ++i) {
- k8x8_t mk;
- simdgroup_load(mk, pk + i*8, args.nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+ simdgroup_barrier(mem_flags::mem_none);
- q8x8_t mq;
- simdgroup_load(mq, sq + i*8, DK);
simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ pk += 8;
+ pq += 8;
}
} else {
- for (short ii = 0; ii < DK16; ii += 4) {
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+ k8x8_t mk[2];
+ q8x8_t mq[2];
- if (DK16%4 == 0) {
- // the head is evenly divisible by 4*16 = 64, so no need for bound checks
- {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
+ FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+ simdgroup_barrier(mem_flags::mem_none);
- simdgroup_barrier(mem_flags::mem_threadgroup);
+ simdgroup_load(mk[0], pk + 0*8, NS10, 0, true);
+ simdgroup_load(mk[1], pk + 1*8, NS10, 0, true);
- #pragma unroll(4)
- for (short k = 0; k < 4; ++k) {
- k8x8_t mk;
- q8x8_t mq;
+ simdgroup_load(mq[0], pq + 0*8, DK);
+ simdgroup_load(mq[1], pq + 1*8, DK);
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ simdgroup_barrier(mem_flags::mem_none);
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
- } else {
- if (ii + tx < DK16) {
- k4x4_t tmp;
- deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
- sk4x4[4*ty + tx] = tmp;
- }
+ simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
+ simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
- simdgroup_barrier(mem_flags::mem_threadgroup);
+ pk += 16;
+ pq += 16;
+ }
+ }
- for (short k = 0; k < 4 && ii + k < DK16; ++k) {
- k8x8_t mk;
- q8x8_t mq;
+ simdgroup_store(mqk, ps, SH, 0, false);
- simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ pk += 8*(NSG*NS10 - DK8);
+ pq += 8*(NSG*0 - DK8);
+ ps += 8*(NSG);
+ }
+ } else {
+ // TODO: this is the quantized K cache branch - not optimized yet
+ for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
+ const short cc = ccc*NSG + sgitg;
- simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
- simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
- simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
- }
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
+ qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
+
+ for (short ii = 0; ii < DK16; ii += 4) {
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
+
+ if (DK16%4 == 0) {
+ // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+ {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ FOR_UNROLL (short k = 0; k < 4; ++k) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+ }
+ } else {
+ if (ii + tx < DK16) {
+ k4x4_t tmp;
+ deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+ sk4x4[4*ty + tx] = tmp;
+ }
+
+ simdgroup_barrier(mem_flags::mem_threadgroup);
+
+ for (short k = 0; k < 4 && ii + k < DK16; ++k) {
+ k8x8_t mk;
+ q8x8_t mq;
+
+ simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
+
+ simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
+ simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
+ simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
}
}
}
- // cast qk_t -> s_t
- //s8x8_t mqks(1.0f);
- //simdgroup_multiply(mqks, mqk, mqks);
- //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
-
- simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
+ simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
}
}
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
// online softmax
- {
- for (ushort j = 0; j < Q; ++j) {
- const float m = M[j];
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
- // scale and apply the logitcap / mask
- float s = ss[j*TS + tiisg]*args.scale;
+ const float m = M[jj];
- if (args.logit_softcap != 0.0f) {
- s = args.logit_softcap*precise::tanh(s);
- }
+ // scale and apply the logit softcap / mask
+ float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;
+
+ if (FC_flash_attn_ext_has_scap) {
+ s2 = args.logit_softcap*precise::tanh(s2);
+ }
- // mqk = mqk + mask*slope
- s += slope*ss[j*TS + C + tiisg];
+ // mqk = mqk + slope*mask
+ if (FC_flash_attn_ext_has_bias) {
+ s2 += s2_t(sm2[j*SH + tiisg])*slope;
+ } else {
+ s2 += s2_t(sm2[j*SH + tiisg]);
+ }
- M[j] = simd_max(max(M[j], s));
+ M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
- const float ms = exp(m - M[j]);
- const float vs = exp(s - M[j]);
+ const float ms = exp(m - M[jj]);
+ const float2 vs2 = exp(s2 - M[jj]);
- S[j] = S[j]*ms + simd_sum(vs);
+ S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);
- // the P matrix from the paper (Q rows, C columns)
- ss[j*TS + tiisg] = vs;
+ // the P matrix from the paper (Q rows, C columns)
+ ss2[j*SH/2 + tiisg] = vs2;
- // create a QxQ diagonal matrix for rescaling the output
- if (tiisg == j) {
- ss[j*TS + 2*C + j] = ms;
+ if (DV4 % NW == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
+ const short i = ii*NW + tiisg;
+
+ so4[j*PV4 + i] *= ms;
+ }
+ } else {
+ for (short i = tiisg; i < DV4; i += NW) {
+ so4[j*PV4 + i] *= ms;
}
}
}
- // O = diag(ms)*O
- {
- s8x8_t ms;
- simdgroup_load(ms, ss + 2*C, TS, 0, false);
-
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- simdgroup_multiply(lo[i], ms, lo[i]);
- }
- }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
// O = O + (Q*K^T)*V
{
- for (short cc = 0; cc < C/8; ++cc) {
- s8x8_t vs;
- simdgroup_load(vs, ss + 8*cc, TS, 0, false);
+ // we can read directly from global memory
+ if (is_same<vd4x4_t, v4x4_t>::value) {
+ static_assert(PV8 % NSG == 0, "");
+
+ constexpr short NO = PV8/NSG;
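+ // e.g. with PV = 128 and NSG = 4, each simdgroup accumulates NO = 4 of the PV8 = 16 8x8 output tiles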
- if (is_same<vd4x4_t, v4x4_t>::value) {
- // we can read directly from global memory
- device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+ o8x8_t lo[NO];
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- v8x8_t mv;
- simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20
+ {
+ auto sot = so + 8*sgitg;
- simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]);
+ FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+ simdgroup_load(lo[ii], sot, PV, 0, false);
+
+ sot += 8*NSG;
}
- } else {
- for (short ii = 0; ii < DV16; ii += 4) {
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+ }
+
+ {
+ auto sst = ss;
+
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
+
+ pv += 8*sgitg;
+
+ FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
+ s8x8_t vs;
+ simdgroup_load(vs, sst, SH, 0, false);
+
+ FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+ v8x8_t mv;
+
+ simdgroup_load(mv, pv, NS20, 0, false);
+ simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]);
+
+ pv += 8*NSG;
+ }
+
+ pv += 8*(NS20 - NO*NSG);
+ sst += 8;
+ }
+ }
+
+ {
+ auto sot = so + 8*sgitg;
+
+ FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
+ simdgroup_store(lo[ii], sot, PV, 0, false);
+
+ sot += 8*NSG;
+ }
+ }
+ } else {
+ // TODO: this is the quantized V cache branch - not optimized yet
+
+ const short tx = tiisg%4;
+ const short ty = tiisg/4;
+
+ for (short cc = 0; cc < C/8; ++cc) {
+ s8x8_t vs;
+ simdgroup_load(vs, ss + 8*cc, SH, 0, false);
+
+ for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
if (DV16%4 == 0) {
// no need for bound checks
simdgroup_barrier(mem_flags::mem_threadgroup);
- #pragma unroll(4)
- for (short k = 0; k < 4; ++k) {
- v8x8_t mv;
+ FOR_UNROLL (short k = 0; k < 4; ++k) {
+ v8x8_t mv[2];
+ o8x8_t lo[2];
- simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
+ simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
+ simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
+ simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
+
+ simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
}
} else {
if (ii + tx < DV16) {
simdgroup_barrier(mem_flags::mem_threadgroup);
for (short k = 0; k < 4 && ii + k < DV16; ++k) {
- v8x8_t mv;
+ v8x8_t mv[2];
+ o8x8_t lo[2];
+
+ simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
+ simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
+ simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
- simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]);
+ simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
+ simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
- simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
- simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]);
+ simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
+ simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
}
}
}
}
}
}
- }
- if (sinks != q && sgitg == 0) {
- for (ushort j = 0; j < Q; ++j) {
- const float m = M[j];
- const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
- M[j] = simd_max(max(M[j], s));
+ if (FC_flash_attn_ext_has_sinks) {
+ FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
- const float ms = exp(m - M[j]);
- const float vs = exp(s - M[j]);
+ const float m = M[jj];
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
- S[j] = S[j]*ms + simd_sum(vs);
+ M[jj] = simd_max(max(M[jj], s));
- if (tiisg == j) {
- ss[j*TS + 2*C + j] = ms;
- }
- }
+ const float ms = exp(m - M[jj]);
+ const float vs = exp(s - M[jj]);
- // O = diag(ms)*O
- {
- s8x8_t ms;
- simdgroup_load(ms, ss + 2*C, TS, 0, false);
+ S[jj] = S[jj]*ms + simd_sum(vs);
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- simdgroup_multiply(lo[i], ms, lo[i]);
+ for (short i = tiisg; i < DV4; i += NW) {
+ so4[j*PV4 + i] *= ms;
}
}
}
-
- // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
- for (short j = tiisg; j < Q; j += NW) {
- ss[j*TS + 0] = S[j];
- ss[j*TS + 1] = M[j];
- }
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation
- threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK);
-
- // store result to shared memory in F32
- if (sgitg == 0) {
- for (short i = 0; i < DV8; ++i) {
- //simdgroup_store(lo[i], so + i*8, DV, 0, false);
- simdgroup_float8x8 t(1.0f);
- simdgroup_multiply(t, lo[i], t);
- simdgroup_store(t, so + i*8, DV, 0, false);
+ // store to global memory
+ for (short jj = 0; jj < NQ; ++jj) {
+ const short j = jj*NSG + sgitg;
+ if (iq1 + j >= args.ne01) {
+ break;
}
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // reduce the warps sequentially
- for (ushort sg = 1; sg < nsg; ++sg) {
- if (sgitg == sg) {
- for (short j = tiisg; j < Q; j += NW) {
- const float S0 = ss[j*TS - 1*SH + 0];
- const float S1 = ss[j*TS + 0];
- const float M0 = ss[j*TS - 1*SH + 1];
- const float M1 = ss[j*TS + 1];
-
- const float M = max(M0, M1);
-
- float ms0 = exp(M0 - M);
- float ms1 = exp(M1 - M);
+ device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
- const float S = S0*ms0 + S1*ms1;
+ const float scale = 1.0f/S[jj];
- ss[j*TS + 0] = S;
- ss[j*TS + 1] = M;
+ if (DV4 % NW == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
+ const short i = ii*NW + tiisg;
- ss[j*TS + 2*C + j - 1*SH] = ms0;
- ss[j*TS + 2*C + j ] = ms1;
+ dst4[i] = (float4) so4[j*PV4 + i]*scale;
}
-
- //simdgroup_barrier(mem_flags::mem_threadgroup);
-
- // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
- {
- s8x8_t ms0;
- s8x8_t ms1;
-
- simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false);
- simdgroup_load(ms1, ss + 2*C, TS, 0, false);
-
- #pragma unroll(DV8)
- for (short i = 0; i < DV8; ++i) {
- simdgroup_float8x8 t;
-
- simdgroup_load (t, so + i*8, DV, 0, false);
- simdgroup_multiply(t, ms0, t);
-
- simdgroup_multiply_accumulate(t, ms1, lo[i], t);
- simdgroup_store(t, so + i*8, DV, 0, false);
- }
+ } else {
+ for (short i = tiisg; i < DV4; i += NW) {
+ dst4[i] = (float4) so4[j*PV4 + i]*scale;
}
}
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
}
- threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK);
-
- // final rescale with 1/S and store to global memory
- for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
- const float S = 1.0f/sf[j*TS + 0];
-
- device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
+#undef NS10
+#undef NS20
+}
- for (short i = tiisg; i < DV4; i += NW) {
- dst4[i] = (float4) so4[j*DV4 + i]*S;
- }
+template<
+ typename q_t, // query types in shared memory
+ typename q4_t,
+ typename q8x8_t,
+ typename k_t, // key types in shared memory
+ typename k4x4_t,
+ typename k8x8_t,
+ typename v_t, // value types in shared memory
+ typename v4x4_t,
+ typename v8x8_t,
+ typename qk_t, // Q*K types
+ typename qk8x8_t,
+ typename s_t, // soft-max types
+ typename s2_t,
+ typename s8x8_t,
+ typename o_t, // attention accumulation types
+ typename o4_t,
+ typename o8x8_t,
+ typename kd4x4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+ typename vd4x4_t, // value type in device memory
+ short nl_v,
+ void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+ short DK, // K head size
+ short DV, // V head size
+ short Q = 8, // queries per threadgroup
+ short C = 64> // cache items per threadgroup
+kernel void kernel_flash_attn_ext(
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device const char * sinks,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
+#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+ switch (FC_flash_attn_ext_nsg) {
+ // note: disabled cases to reduce library load time
+ //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
+ //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
+ case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
}
+#undef FWD_TMPL
+#undef FWD_ARGS
}
// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
// template to be able to explore different combinations
//
#define FA_TYPES \
- float, float4, simdgroup_float8x8, \
+ half, half4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
half, half4x4, simdgroup_half8x8, \
float, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
- half, half4, simdgroup_half8x8
- //float, float4, simdgroup_float8x8
+ float, float2, simdgroup_float8x8, \
+ float, float4, simdgroup_float8x8
+ //half, half4, simdgroup_half8x8
#define FA_TYPES_BF \
bfloat, bfloat4, simdgroup_bfloat8x8, \
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
float, simdgroup_float8x8, \
- float, simdgroup_float8x8, \
+ float, float2, simdgroup_float8x8, \
half, half4, simdgroup_half8x8
//float, float4, simdgroup_float8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
-template [[host_name("kernel_flash_attn_ext_f16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_bf16_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
-template [[host_name("kernel_flash_attn_ext_q4_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
-template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
-template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
#undef FA_TYPES
#undef FA_TYPES_BF
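+// function constants for the vec kernel, indexed relative to the FC_FLASH_ATTN_EXT_VEC base;
+// the host is expected to set these (e.g. via MTLFunctionConstantValues) when building each specialized pipeline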
+constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
+constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
+constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
+constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
+
+//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
+//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
+//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];
+
+constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];
+constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];
+constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];
+constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];
+
template<
typename q4_t, // query types in shared memory
typename k4_t, // key types in shared memory
short DV, // V head size
short NE = 4, // head elements per thread
short Q = 1, // queries per threadgroup
- short C = 32> // cache items per threadgroup
-kernel void kernel_flash_attn_ext_vec(
- constant ggml_metal_kargs_flash_attn_ext & args,
+ short C = 32, // cache items per threadgroup
+ short NSG> // number of simd groups
+void kernel_flash_attn_ext_vec_impl(
+ constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device char * dst,
- constant uint16_t & nwg,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
static_assert(DK % 32 == 0, "DK must be divisible by 32");
static_assert(DV % 32 == 0, "DV must be divisible by 32");
- const short nsg = ntg.y; // number of simdgroups
- const short iwg = tgpig[2]%nwg;
+#define NWG (FC_flash_attn_ext_vec_nwg)
+
+#define NS10 (FC_flash_attn_ext_vec_ns10)
+#define NS20 (FC_flash_attn_ext_vec_ns20)
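+
+// NS10/NS20 appear to be the K/V row strides in elements; baking them in as function constants lets the compiler fold the pointer arithmetic below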
+
+ const short iwg = tgpig[2]%NWG;
- const int iq3 = tgpig[2]/nwg;
- const int iq2 = tgpig[1];
- const int iq1 = tgpig[0];
+ const ushort iq3 = tgpig[2]/NWG;
+ const ushort iq2 = tgpig[1];
+ const ushort iq1 = tgpig[0];
constexpr short DK4 = DK/4;
constexpr short DV4 = DV/4;
+
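+ // padded head sizes: PAD2 presumably rounds DK/DV up so that the shared-memory strides below stay uniformly aligned across head sizes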
+ constexpr short PK = PAD2(DK, 128);
+ constexpr short PK4 = PK/4;
+
+ constexpr short PV = PAD2(DV, 128);
+ constexpr short PV4 = PV/4;
+
constexpr short NW = N_SIMDWIDTH;
constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup workloads
constexpr short SH = 4*C; // shared memory per simdgroup
- const short T = DK + nsg*SH; // shared memory size per query in (half)
+ static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
+ static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
+
+ const short T = PK + NSG*SH; // shared memory size per query, in half elements
- //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
- threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
- // store the result for all queries in local memory (the O matrix from the paper)
- o4_t lo[DV4/NL];
+ // store the result for all queries in shared memory (the O matrix from the paper)
+ so4 += tiisg;
+
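+ // advance q/k/v to the current head and batch; ikv2/ikv3 broadcast the shared KV heads across the query heads (GQA)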
+ {
+ q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
+
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+ k += ikv2*args.nb12 + ikv3*args.nb13;
+ v += ikv2*args.nb22 + ikv3*args.nb23;
+ }
// load heads from Q to shared memory
- device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
+ device const float4 * q4 = (device const float4 *) ((device const char *) q);
- for (short i = tiisg; i < DK4; i += NW) {
- if (iq1 < args.ne01) {
+ for (short i = tiisg; i < PK4; i += NW) {
+ if (iq1 < args.ne01 && i < DK4) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
}
}
- // zero out lo
+ // zero out the output accumulator (so4)
for (short i = 0; i < DV4/NL; ++i) {
- lo[i] = (o4_t) 0.0f;
+ so4[i*NL] = (o4_t) 0.0f;
}
// zero out shared memory SH
{
float S = 0.0f;
- float M = -__FLT_MAX__/2;
+ float M = -FLT_MAX/2;
// thread indices inside the simdgroup
const short tx = tiisg%NL;
const short ty = tiisg/NL;
- // broadcast kv
- //const short rk2 = args.ne02/args.ne12;
- //const short rk3 = args.ne03/args.ne13;
-
- const short ikv2 = iq2/(args.ne02/args.ne_12_2);
- const short ikv3 = iq3/(args.ne03/args.ne_12_3);
-
- const bool has_mask = mask != q;
-
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
float slope = 1.0f;
// ALiBi
- if (args.max_bias > 0.0f) {
+ if (FC_flash_attn_ext_vec_has_bias) {
const short h = iq2;
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
+ for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
const int ic = ic0 + C*sgitg;
if (ic >= args.ne11) {
break;
}
- if (has_mask) {
+ if (FC_flash_attn_ext_vec_has_mask) {
sm[tiisg] = pm[ic + tiisg];
}
// Q*K^T
{
- // each simdgroup processes 1 query and NE (NW/NL) head elements
- for (short cc = 0; cc < C/NE; ++cc) {
- qk_t mqk = 0.0f;
+ device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
+ threadgroup const q4_t * pq4 = sq4;
- device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11 + ikv2*args.nb12 + ikv3*args.nb13));
+ pk4 += ty*NS10/4 + tx;
+ pq4 += tx;
- #pragma unroll(DK4/NL)
- for (short ii = 0; ii < DK4; ii += NL) {
- const short i = ii + tx;
+ qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f };
+
+ // each simdgroup processes 1 query and NE (NW/NL) cache elements
+ FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
+ if (is_same<kd4_t, k4_t>::value) {
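+ // K is stored in the same type as the shared-memory operand: no dequantization needed, read it directly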
+ FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
+ mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
+ }
+ } else {
+ device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
k4_t mk;
- deq_k_t4(pk + i/nl_k, i%nl_k, mk);
- // note: this is less precise than the version below
- //mqka[0] += dot(mq[0], mk[0]);
- //mqka[1] += dot(mq[1], mk[1]);
- //mqka[2] += dot(mq[2], mk[2]);
- //mqka[3] += dot(mq[3], mk[3]);
+ FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
+ const short i = ii*NL + tx;
- //q4x4_t mq = sq4x4[i];
- //mqka[0] += dot((float4) mq[0], (float4) mk[0]);
- //mqka[1] += dot((float4) mq[1], (float4) mk[1]);
- //mqka[2] += dot((float4) mq[2], (float4) mk[2]);
- //mqka[3] += dot((float4) mq[3], (float4) mk[3]);
+ deq_k_t4(pk + i/nl_k, i%nl_k, mk);
- mqk += dot((float4) mk, (float4) sq4[i]);
+ mqk[cc] += dot((float4) mk, (float4) sq4[i]);
+ }
}
- static_assert(NE > 1, "NE must be > 1"); // note: not sure why NE == 1 fails
+ if (NE == 1) {
+ mqk[cc] = simd_sum(mqk[cc]);
+ } else {
+ // simdgroup reduce (NE = 4)
+ // [ 0 .. 7] -> [ 0]
+ // [ 8 .. 15] -> [ 8]
+ // [16 .. 23] -> [16]
+ // [24 .. 31] -> [24]
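+ // (smaller NE means a larger NL = NW/NE lanes per dot product, so more shuffle steps are enabled below)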
+ if (NE <= 1) {
+ mqk[cc] += simd_shuffle_down(mqk[cc], 16);
+ }
+ if (NE <= 2) {
+ mqk[cc] += simd_shuffle_down(mqk[cc], 8);
+ }
+ if (NE <= 4) {
+ mqk[cc] += simd_shuffle_down(mqk[cc], 4);
+ }
+ if (NE <= 8) {
+ mqk[cc] += simd_shuffle_down(mqk[cc], 2);
+ }
+ if (NE <= 16) {
+ mqk[cc] += simd_shuffle_down(mqk[cc], 1);
+ }
- // simdgroup reduce (NE = 4)
- // [ 0 .. 7] -> [ 0]
- // [ 8 .. 15] -> [ 8]
- // [16 .. 23] -> [16]
- // [24 .. 31] -> [24]
- if (NE <= 1) {
- mqk += simd_shuffle_down(mqk, 16);
- }
- if (NE <= 2) {
- mqk += simd_shuffle_down(mqk, 8);
- }
- if (NE <= 4) {
- mqk += simd_shuffle_down(mqk, 4);
- }
- if (NE <= 8) {
- mqk += simd_shuffle_down(mqk, 2);
+ // broadcast
+ mqk[cc] = simd_shuffle(mqk[cc], NL*ty);
}
- if (NE <= 16) {
- mqk += simd_shuffle_down(mqk, 1);
- }
-
- // mqk = mqk*scale + mask*slope
- if (tx == 0) {
- mqk *= args.scale;
+ }
- if (args.logit_softcap != 0.0f) {
- mqk = args.logit_softcap*precise::tanh(mqk);
- }
+ if (FC_flash_attn_ext_vec_has_mask &&
+ !FC_flash_attn_ext_vec_has_scap &&
+ !FC_flash_attn_ext_vec_has_bias) {
+ ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);
+ } else {
+ mqk[tx] *= args.scale;
- mqk += sm[NE*cc + ty]*slope;
+ if (FC_flash_attn_ext_vec_has_scap) {
+ mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);
+ }
- ss[NE*cc + ty] = mqk;
+ if (FC_flash_attn_ext_vec_has_bias) {
+ mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;
+ } else {
+ mqk[tx] += (qk_t) sm[NE*tx + ty];
}
+
+ ss[NE*tx + ty] = mqk[tx];
}
}
ss[tiisg] = vs;
// O = diag(ms)*O
- #pragma unroll(DV4/NL)
- for (short ii = 0; ii < DV4; ii += NL) {
- lo[ii/NL] *= ms;
+ if ((DV4/NL % NW == 0) || ty == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ so4[ii*NL] *= ms;
+ }
}
}
// O = O + (Q*K^T)*V
{
- //#pragma unroll(C/NE)
- for (short cc = 0; cc < C/NE; ++cc) {
- device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21 + ikv2*args.nb22 + ikv3*args.nb23));
+ o4_t lo[DV4/NL];
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ lo[ii] = 0.0f;
+ }
- const s4_t ms(ss[NE*cc + ty]);
+ if (is_same<vd4_t, v4_t>::value) {
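+ // fast path: V needs no dequantization, so accumulate directly from device memory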
+ device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
- #pragma unroll(DV4/NL)
- for (short ii = 0; ii < DV4; ii += NL) {
- const short i = ii + tx;
+ pv4 += ty*NS20/4 + tx;
- v4_t mv;
- deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
+ const auto sst = ss + ty;
- lo[ii/NL] += o4_t(float4(mv)*float4(ms));
+ FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));
+ }
+ }
+ } else {
+ FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
+ device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
+
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ const short i = ii*NL + tx;
+
+ v4_t mv;
+ deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
+
+ lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));
+ }
+ }
+ }
+
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ if (NE > 1) {
+ lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);
+ lo[ii][1] += simd_shuffle_down(lo[ii][1], 16);
+ lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);
+ lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);
+ }
+
+ if (NE > 2) {
+ lo[ii][0] += simd_shuffle_down(lo[ii][0], 8);
+ lo[ii][1] += simd_shuffle_down(lo[ii][1], 8);
+ lo[ii][2] += simd_shuffle_down(lo[ii][2], 8);
+ lo[ii][3] += simd_shuffle_down(lo[ii][3], 8);
+ }
+
+ if (NE > 4) {
+ lo[ii][0] += simd_shuffle_down(lo[ii][0], 4);
+ lo[ii][1] += simd_shuffle_down(lo[ii][1], 4);
+ lo[ii][2] += simd_shuffle_down(lo[ii][2], 4);
+ lo[ii][3] += simd_shuffle_down(lo[ii][3], 4);
+ }
+
+ if (NE > 8) {
+ lo[ii][0] += simd_shuffle_down(lo[ii][0], 2);
+ lo[ii][1] += simd_shuffle_down(lo[ii][1], 2);
+ lo[ii][2] += simd_shuffle_down(lo[ii][2], 2);
+ lo[ii][3] += simd_shuffle_down(lo[ii][3], 2);
+ }
+
+ if (NE > 16) {
+ lo[ii][0] += simd_shuffle_down(lo[ii][0], 1);
+ lo[ii][1] += simd_shuffle_down(lo[ii][1], 1);
+ lo[ii][2] += simd_shuffle_down(lo[ii][2], 1);
+ lo[ii][3] += simd_shuffle_down(lo[ii][3], 1);
+ }
+ }
+
+ if ((DV4/NL % NW == 0) || ty == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ so4[ii*NL] += lo[ii];
}
}
}
}
- if (sinks != q && sgitg == 0 && iwg == 0) {
+ if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {
const float m = M;
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
S = S*ms + simd_sum(vs);
-#pragma unroll(DV4/NL)
- for (short ii = 0; ii < DV4; ii += NL) {
- lo[ii/NL] *= ms;
+ if ((DV4/NL % NW == 0) || ty == 0) {
+ FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
+ so4[ii*NL] *= ms;
+ }
}
}
}
}
- // simdgroup reduce (NE = 4)
- // [ 0, 8, 16, 24] -> [ 0]
- // [ 1, 9, 17, 25] -> [ 1]
- // [ 2, 10, 18, 26] -> [ 2]
- // [ 3, 11, 19, 27] -> [ 3]
- // [ 4, 12, 20, 28] -> [ 4]
- // [ 5, 13, 21, 29] -> [ 5]
- // [ 6, 14, 22, 30] -> [ 6]
- // [ 7, 15, 23, 31] -> [ 7]
- for (short ii = 0; ii < DV4; ii += NL) {
- if (NE > 1) {
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
- }
-
- if (NE > 2) {
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
- }
-
- if (NE > 4) {
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
- }
-
- if (NE > 8) {
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
- }
-
- if (NE > 16) {
- lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
- lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
- lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
- lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
- }
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
-
- // store results to shared memory
- for (short i = tiisg; i < DV4; i += NL) {
- sr4[i] = lo[i/NL];
- }
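+ // undo the per-lane offset applied at the start; the cross-simdgroup reduction below indexes so4 uniformly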
+ so4 -= tiisg;
threadgroup_barrier(mem_flags::mem_threadgroup);
// parallel reduce
- for (short r = nsg/2; r > 0; r >>= 1) {
+ for (short r = NSG/2; r > 0; r >>= 1) {
if (sgitg < r) {
const float S0 = ss[ 0];
const float S1 = ss[r*(SH/2) + 0];
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
for (short i = tiisg; i < DV4; i += NW) {
- sr4[i] = sr4[i]*ms0 + sr4[i + r*DV4]*ms1;
+ so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;
}
}
const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
device float4 * dst4 = (device float4 *) dst;
- device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
+ device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
- const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
+ const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
// interleave the workgroup data
for (short i = tiisg; i < DV4; i += NW) {
- dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
+ dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;
}
// store S and M
- if (nwg > 1 && tiisg == 0) {
- dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
- dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
+ if (NWG > 1) {
+ if (tiisg == 0) {
+ dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];
+ dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];
+ }
}
}
+
+#undef NWG
+#undef NS10
+#undef NS20
+}
+
+template<
+ typename q4_t, // query types in shared memory
+ typename k4_t, // key types in shared memory
+ typename v4_t, // value types in shared memory
+ typename qk_t, // Q*K types
+ typename s_t, // soft-max types
+ typename s4_t,
+ typename o4_t, // attention accumulation types
+ typename kd4_t, // key type in device memory
+ short nl_k,
+ void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
+ typename vd4_t, // value type in device memory
+ short nl_v,
+ void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
+ short DK, // K head size
+ short DV, // V head size
+ short NE = 4, // head elements per thread
+ short Q = 1, // queries per threadgroup
+ short C = 32> // cache items per threadgroup
+kernel void kernel_flash_attn_ext_vec(
+ constant ggml_metal_kargs_flash_attn_ext_vec & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device const char * sinks,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
+#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
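+ // FC_flash_attn_ext_vec_nsg is a function constant, so this switch should be resolved when the pipeline is specialized and the unused branches compiled out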
+ switch (FC_flash_attn_ext_vec_nsg) {
+ // note: disabled cases to reduce library load time
+ case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
+ case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
+ case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+ //case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
+ //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
+ //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
+ }
+#undef FWD_TMPL
+#undef FWD_ARGS
}
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 8>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk192_hv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 4>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
#undef FA_TYPES
-kernel void kernel_flash_attn_ext_reduce(
- constant ggml_metal_kargs_flash_attn_ext_reduce & args,
+constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
+constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
+
+kernel void kernel_flash_attn_ext_vec_reduce(
+ constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,
device const char * htmp,
device char * dst,
uint tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+#define NWG (FC_flash_attn_ext_vec_reduce_NWG)
+#define DV (FC_flash_attn_ext_vec_reduce_DV)
+
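+ // each output row was computed by NWG workgroups; merge their partial (O, S, M) results with the
+ // standard streaming-softmax rescale: take the global max M, scale each partial by exp(M_i - M), then normalize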
const uint64_t rid = tgpig;
- const short nwg = 32;
const short iwg = tiisg;
- const short DV = args.ne20;
- const short DV4 = DV/4;
- device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
- device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
- device float4 * dst4 = (device float4 *) dst + rid*DV4;
+ device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG;
- float S = ss[rid*(2*nwg) + 2*iwg + 0];
- float M = ss[rid*(2*nwg) + 2*iwg + 1];
+ float S = ss[rid*(2*NWG) + 2*iwg + 0];
+ float M = ss[rid*(2*NWG) + 2*iwg + 1];
const float m = simd_max(M);
const float ms = exp(M - m);
S = 1.0f/simd_sum(S*ms);
- for (int i = sgitg; i < DV4; i += nwg) {
- const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
+ const short DV4 = DV/4;
+
+ device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;
+ device float4 * dst4 = (device float4 *) dst + rid*DV4;
+
+ for (short i = sgitg; i < DV4; i += NWG) {
+ const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);
if (iwg == 0) {
dst4[i] = v*S;
}
}
+
+#undef NWG
+#undef DV
}
template<typename T>
const int32_t i10 = i01;
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
- device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
+ device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
#pragma unroll(4)
for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
+ simdgroup_barrier(mem_flags::mem_none);
+
#pragma unroll(4)
for (short i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
}
- simdgroup_barrier(mem_flags::mem_none);
-
#pragma unroll(2)
for (short i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
}
+ simdgroup_barrier(mem_flags::mem_none);
+
#pragma unroll(8)
for (short i = 0; i < 8; i++){
simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);