SPH
Pool.cpp
Go to the documentation of this file.
1 #include "thread/Pool.h"
3 
5 
6 SharedPtr<ThreadPool> ThreadPool::globalInstance;
7 
8 struct ThreadContext {
10  ThreadPool* parentPool = nullptr;
11 
13  Size index = Size(-1);
14 
17 };
18 
19 static thread_local ThreadContext threadLocalContext;
20 
21 Task::Task(const Function<void()>& callable)
22  : callable(callable) {}
23 
25  SPH_ASSERT(this->completed());
26 }
27 
28 void Task::wait() {
29  if (threadLocalContext.parentPool) {
30  // worker thread, we can work on tasks
31  while (tasksLeft > 0) {
32  threadLocalContext.parentPool->processTask(false);
33  }
34  } else {
35  std::unique_lock<std::mutex> lock(waitMutex);
36  if (tasksLeft > 0) {
37  // non-worker thread, simply wait until no tasks are left
38  waitVar.wait(lock, [this] { return tasksLeft == 0; });
39  }
40  }
41  SPH_ASSERT(tasksLeft == 0);
42 
43  if (caughtException) {
44  std::rethrow_exception(caughtException);
45  }
46 }
47 
48 bool Task::completed() const {
49  return tasksLeft == 0;
50 }
51 
52 bool Task::isRoot() const {
53  return parent == nullptr;
54 }
55 
57  return parent;
58 }
59 
61  return threadLocalContext.current;
62 }
63 
65  parent = task;
66 
67  // sanity check to avoid circular dependency
68  SPH_ASSERT(!parent || parent->getParent().get() != RawPtr<Task>(this));
69 
70  if (task) {
71  task->addReference();
72  }
73 }
74 
75 void Task::setException(std::exception_ptr exception) {
76  if (this->isRoot()) {
77  caughtException = exception;
78  } else {
79  parent->setException(exception);
80  }
81 }
82 
84  // this may be called from within another task, so we override the threadLocalContext.current for
85  // this scope only
86  SharedPtr<Task> callingTask = threadLocalContext.current;
87  threadLocalContext.current = this->sharedFromThis();
88  auto guard = finally([this, callingTask] {
89  threadLocalContext.current = callingTask;
90  this->removeReference();
91  });
92 
93  try {
94  callable();
95  } catch (...) {
96  // store caught exception, replacing the previous one
97  this->setException(std::current_exception());
98  }
99 }
100 
101 void Task::addReference() {
102  std::unique_lock<std::mutex> lock(waitMutex);
103  SPH_ASSERT(tasksLeft > 0);
104  tasksLeft++;
105 }
106 
107 void Task::removeReference() {
108  std::unique_lock<std::mutex> lock(waitMutex);
109  tasksLeft--;
110  SPH_ASSERT(tasksLeft >= 0);
111 
112  if (tasksLeft == 0) {
113  if (!this->isRoot()) {
114  parent->removeReference();
115  }
116  waitVar.notify_all();
117  }
118 }
119 
120 ThreadPool::ThreadPool(const Size numThreads, const Size granularity)
121  : threads(numThreads == 0 ? std::thread::hardware_concurrency() : numThreads)
122  , granularity(granularity) {
123  SPH_ASSERT(!threads.empty());
124  auto loop = [this](const Size index) {
125  // setup the thread
126  threadLocalContext.parentPool = this;
127  threadLocalContext.index = index;
128 
129  // run the loop
130  while (!stop) {
131  this->processTask(true);
132  }
133  };
134  stop = false;
135  tasksLeft = 0;
136 
137  // start all threads
138  Size index = 0;
139  for (auto& t : threads) {
140  t = makeAuto<std::thread>(loop, index);
141  ++index;
142  }
143 }
144 
146  waitForAll();
147  stop = true;
148  taskVar.notify_all();
149 
150  for (auto& t : threads) {
151  if (t->joinable()) {
152  t->join();
153  }
154  }
155 }
156 
158  SharedPtr<Task> handle = makeShared<Task>(task);
159  handle->setParent(threadLocalContext.current);
160 
161  {
162  std::unique_lock<std::mutex> lock(waitMutex);
163  ++tasksLeft;
164  }
165  {
166  std::unique_lock<std::mutex> lock(taskMutex);
167  tasks.emplace(handle);
168  }
169  taskVar.notify_all();
170  return handle;
171 }
172 
173 void ThreadPool::processTask(const bool wait) {
174  SharedPtr<Task> task = this->getNextTask(wait);
175  if (task) {
176  // run the task
177  task->runAndNotify();
178 
179  std::unique_lock<std::mutex> lock(waitMutex);
180  --tasksLeft;
181  } else {
182  SPH_ASSERT(!wait || stop);
183  }
184  waitVar.notify_one();
185 }
186 
188  std::unique_lock<std::mutex> lock(waitMutex);
189  if (tasksLeft > 0) {
190  waitVar.wait(lock, [this] { return tasksLeft == 0; });
191  }
192  SPH_ASSERT(tasks.empty() && tasksLeft == 0);
193 }
194 
196  return granularity;
197 }
198 
200  if (threadLocalContext.parentPool != this || threadLocalContext.index == Size(-1)) {
201  // thread either belongs to different ThreadPool or isn't a worker thread
202  return NOTHING;
203  }
204  return threadLocalContext.index;
205 }
206 
208  return threads.size();
209 }
210 
212  if (!globalInstance) {
213  globalInstance = makeShared<ThreadPool>();
214  }
215  return globalInstance;
216 }
217 
218 SharedPtr<Task> ThreadPool::getNextTask(const bool wait) {
219  std::unique_lock<std::mutex> lock(taskMutex);
220 
221  if (wait) {
222  // wait till a task is available
223  taskVar.wait(lock, [this] { return !tasks.empty() || stop; });
224  }
225 
226  // remove the task from the queue and return it
227  if (!stop && !tasks.empty()) {
228  SharedPtr<Task> task = tasks.front();
229  tasks.pop();
230  return task;
231  } else {
232  return nullptr;
233  }
234 }
235 
#define SPH_ASSERT(x,...)
Definition: Assert.h:94
NAMESPACE_SPH_BEGIN
Definition: BarnesHut.cpp:13
Wraps a functor and executes it once the wrapper goes out of scope.
uint32_t Size
Integral type used to index arrays (by default).
Definition: Globals.h:16
#define NAMESPACE_SPH_END
Definition: Object.h:12
const NothingType NOTHING
Definition: Optional.h:16
Simple thread pool with fixed number of threads.
INLINE TCounter size() const noexcept
Definition: Array.h:193
INLINE bool empty() const noexcept
Definition: Array.h:201
Non-owning wrapper of pointer.
Definition: RawPtr.h:19
SharedPtr< Task > sharedFromThis() const
Definition: SharedPtr.h:426
INLINE RawPtr< T > get() const
Definition: SharedPtr.h:223
void runAndNotify()
Definition: Pool.cpp:83
virtual bool completed() const override
Checks whether the task has already finished.
Definition: Pool.cpp:48
~Task()
Definition: Pool.cpp:24
static SharedPtr< Task > getCurrent()
Returns the currently executed task, or nullptr if no task is being executed on this thread.
Definition: Pool.cpp:60
SharedPtr< Task > getParent() const
Definition: Pool.cpp:56
Task(const Function< void()> &callable)
Definition: Pool.cpp:21
virtual void wait() override
Waits till the task and all the child tasks are completed.
Definition: Pool.cpp:28
bool isRoot() const
Returns true if this is the top-most task.
Definition: Pool.cpp:52
void setException(std::exception_ptr exception)
Saves exception into the task.
Definition: Pool.cpp:75
void setParent(SharedPtr< Task > parent)
Assigns a task that spawned this task.
Definition: Pool.cpp:64
Thread pool capable of executing tasks concurrently.
Definition: Pool.h:70
ThreadPool(const Size numThreads=0, const Size granularity=1000)
Initializes the thread pool with the given number of threads (0 uses the number of hardware threads).
Definition: Pool.cpp:120
~ThreadPool()
Definition: Pool.cpp:145
virtual SharedPtr< ITask > submit(const Function< void()> &task) override
Submits a task into the thread pool.
Definition: Pool.cpp:157
static SharedPtr< ThreadPool > getGlobalInstance()
Returns the global instance of the thread pool.
Definition: Pool.cpp:211
void waitForAll()
Blocks until all submitted tasks have finished.
Definition: Pool.cpp:187
virtual Size getRecommendedGranularity() const override
Returns a value of granularity that is expected to perform well with the current thread count.
Definition: Pool.cpp:195
virtual Optional< Size > getThreadIdx() const override
Returns the index of this thread, or NOTHING if this thread was not invoked by the thread pool.
Definition: Pool.cpp:199
virtual Size getThreadCnt() const override
Returns the number of threads used by this thread pool.
Definition: Pool.cpp:207
Overload of std::swap for Sph::Array.
Definition: Array.h:578
ThreadPool * parentPool
Owner of this thread.
Definition: Pool.cpp:10
Size index
Index of this thread in the parent thread pool (not to be confused with std::this_thread::get_id()).
Definition: Pool.cpp:13
SharedPtr< Task > current
Task currently processed by this thread.
Definition: Pool.cpp:16