1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
|
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index da8813fd2789..86c30dcaaae2 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2452,6 +2452,8 @@ extern "C" {
uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
bool strict_cpu; // strict cpu placement
bool paused; // start in paused state
+ void (*thread_create_callback)(void); // callback invoked when thread is created
+ void (*thread_destroy_callback)(void); // callback invoked when thread is destroyed
};
struct ggml_threadpool; // forward declaration, see ggml.c
diff --git a/ggml/src/ggml-cpu/ggml-cpu-c.c b/ggml/src/ggml-cpu/ggml-cpu-c.c
index f6bea3df34a0..fbefc0f45ec1 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-c.c
+++ b/ggml/src/ggml-cpu/ggml-cpu-c.c
@@ -463,6 +463,9 @@ struct ggml_threadpool {
int32_t prio; // Scheduling priority
uint32_t poll; // Polling level (0 - no polling)
+ void (*thread_create_callback)(void);
+ void (*thread_destroy_callback)(void);
+
enum ggml_status ec;
};
@@ -2959,6 +2962,10 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
struct ggml_threadpool * threadpool = state->threadpool;
+ if (threadpool->thread_create_callback) {
+ threadpool->thread_create_callback();
+ }
+
ggml_thread_apply_priority(threadpool->prio);
if (ggml_thread_cpumask_is_valid(state->cpumask)) {
ggml_thread_apply_affinity(state->cpumask);
@@ -2990,6 +2997,10 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) {
}
}
+ if (threadpool->thread_destroy_callback) {
+ threadpool->thread_destroy_callback();
+ }
+
return (thread_ret_t) 0;
}
@@ -3049,6 +3060,8 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
threadpool->n_threads_cur = tpp->n_threads;
threadpool->poll = tpp->poll;
threadpool->prio = tpp->prio;
+ threadpool->thread_create_callback = tpp->thread_create_callback;
+ threadpool->thread_destroy_callback = tpp->thread_destroy_callback;
threadpool->ec = GGML_STATUS_SUCCESS;
}
|