ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
- theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+ theta_scale_ne, theta_scale_nb, 1);
float start = 0;
float step = 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void * yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
- theta_scale_nb, GGML_MAX_DIMS);
+ theta_scale_nb, 1);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
GGML_ABORT("CANN error");
}
+// Thread-local variable to record the current device of this thread.
+thread_local int g_current_cann_device = -1;
+
/**
- * @brief Sets the device to be used by CANN.
+ * @brief Set the CANN device to be used.
*
- * @param device The device ID to set.
+ * @param device The target device ID to set.
*/
void ggml_cann_set_device(const int32_t device) {
- int current_device = -1;
- aclrtGetDevice(¤t_device);
+ // int current_device = -1;
+ // Note: In some CANN versions, if no device has been set yet,
+ // aclrtGetDevice(¤t_device) may return 0 by default.
+ // aclrtGetDevice(¤t_device);
- if (device == current_device) {
+ // If the current device is already the target one, no need to switch.
+ if (device == g_current_cann_device) {
return;
}
+
+ // Switch to the new device.
ACL_CHECK(aclrtSetDevice(device));
+
+ // Update the global device record.
+ g_current_cann_device = device;
}
/**