std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
- prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
+ if (node->src[src]) {
+ prop.src_address[src] = node->src[src]->data;
+ std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+ std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+ } else {
+ prop.src_address[src] = nullptr;
+ std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+ std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+ }
}
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
* @param graph_node_properties The stored properties of a CANN graph node.
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
*/
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+static bool ggml_graph_node_has_matching_properties(
+ ggml_tensor * node,
+ ggml_graph_node_properties * graph_node_properties) {
if (node->data != graph_node_properties->node_address &&
- node->op != GGML_OP_VIEW) {
+ node->op != GGML_OP_VIEW) {
return false;
}
+
if (node->op != graph_node_properties->node_op) {
return false;
}
+
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != graph_node_properties->ne[i]) {
return false;
return false;
}
}
+
for (int i = 0; i < GGML_MAX_SRC; i++) {
- if (node->src[i] &&
- node->src[i]->data != graph_node_properties->src_address[i] &&
- node->op != GGML_OP_VIEW
- ) {
- return false;
+ if (node->src[i]) {
+ if (node->src[i]->data != graph_node_properties->src_address[i] &&
+ node->op != GGML_OP_VIEW) {
+ return false;
+ }
+
+ for (int d = 0; d < GGML_MAX_DIMS; d++) {
+ if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
+ return false;
+ }
+ if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
+ return false;
+ }
+ }
+ } else {
+ if (graph_node_properties->src_address[i] != nullptr) {
+ return false;
+ }
}
}
- if (node->op == GGML_OP_SCALE &&
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
- return false;
+
+ if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+ return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
}