]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : fix asserts for setThreadgroupMemoryLength (close #1435)
authorGeorgi Gerganov <redacted>
Tue, 7 Nov 2023 09:02:16 +0000 (11:02 +0200)
committerGeorgi Gerganov <redacted>
Tue, 7 Nov 2023 09:02:16 +0000 (11:02 +0200)
ggml-metal.m

index 43d0dff09ecb627638318789fc003da21ecee0a0..3bee83970b4c3d0e0d9bff3077e9af4d118c8101 100644 (file)
@@ -1030,7 +1030,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                             [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                             [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1342,7 +1342,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
@@ -1361,7 +1361,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
                             [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
                             [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);