admin 发表于 2022-2-12 14:35:42

rust 实战 - 实现一个线程工作池 ThreadPool


<h1 id="如何实现一个线程池">如何实现一个线程池</h1>
<p>线程池:一种线程使用模式。线程过多会带来调度开销,进而影响缓存局部性和整体性能。而线程池维护着多个线程,等待着监督管理者分配可并发执行的任务。这避免了在处理短时间任务时创建与销毁线程的代价。线程池不仅能够保证内核的充分利用,还能防止过分调度。可用线程数量应该取决于可用的并发处理器、处理器内核、内存、网络sockets等的数量。 例如,对于计算密集型任务,线程数一般取cpu数量+2比较合适,线程数过多会导致额外的线程切换开销。</p>
<p>如何定义线程池Pool呢,首先最大线程数量肯定要作为线程池的一个属性,并且在new Pool时创建指定的线程。</p>
<p>线程池Pool</p>
<pre><code>pub struct Pool {
max_workers: usize, // 定义最大线程数
}

impl Pool {
fn new(max_workers: usize) -&gt; Pool {}
fn execute&lt;F&gt;(&amp;self, f:F) where F: FnOnce() + 'static + Send {}
}

</code></pre>
<p>用<code>execute</code>来执行任务,<code>F: FnOnce() + 'static + Send</code> 是使用thread::spawn线程执行需要满足的trait, 代表F是一个能在线程里执行的闭包函数。</p>
<p>另一点自然而然会想到在Pool添加一个线程数组, 这个线程数组就是用来执行任务的。比如<code>Vec&lt;Thread&gt;</code> balabala。这里的线程是活的,是一个个不断接受任务然后执行的实体。<br>
可以看作在一个线程里不断执行获取任务并执行的Worker。</p>
<pre><code>struct Worker where
{
    _id: usize, // worker 编号
}
</code></pre>
<p>要怎么把任务发送给Worker执行呢?mpsc(multi producer single consumer) 多生产者单消费者可以满足我们的需求,<code>let (tx, rx) = mpsc::channel()</code> 可以获取到一对发送端和接收端。<br>
把发送端添加到Pool里面,把接收端添加到Worker里面。Pool通过channel将任务发送给多个worker消费执行。</p>
<p><strong>这里有一点需要特别注意,channel的接收端receiver需要安全的在多个线程间共享</strong>,因此需要用<code>Arc&lt;Mutex::&lt;T&gt;&gt;</code>来包裹起来,也就是用锁来解决并发冲突。</p>
<p>Pool的完整定义</p>
<pre><code>pub struct Pool {
    workers: Vec&lt;Worker&gt;,
    max_workers: usize,
    sender: mpsc::Sender&lt;Message&gt;
}
</code></pre>
<p>该是时候定义我们要发给Worker的消息Message了<br>
定义如下的枚举值</p>
<pre><code>type Job = Box&lt;dyn FnOnce() + 'static + Send&gt;;
enum Message {
    ByeBye,
    NewJob(Job),
}
</code></pre>
<p>Job是一个要发送给Worker执行的闭包函数,这里ByeBye用来通知Worker可以终止当前的执行,退出线程。</p>
<p>只剩下实现Worker和Pool的具体逻辑了。</p>
<p>Worker的实现</p>
<pre><code>impl Worker
{
    fn new(id: usize, receiver: Arc::&lt;Mutex&lt;mpsc::Receiver&lt;Message&gt;&gt;&gt;) -&gt; Worker {
      let t = thread::spawn( move || {
            loop {
                let receiver = receiver.lock().unwrap();
                let message=receiver.recv().unwrap();
                match message {
                  Message::NewJob(job) =&gt; {
                        println!("do job from worker[{}]", id);
                        job();
                  },
                  Message::ByeBye =&gt; {
                        println!("ByeBye from worker[{}]", id);
                        break
                  },
                }
            }
      });

      Worker {
            _id: id,
            t: Some(t),
      }
    }
}
</code></pre>
<p><strong>let message = receiver.lock().unwrap().recv().unwrap();</strong> 这里获取锁后从receiver获取到消息体,然后let message结束后rust的生命周期会自动释放掉锁。<br>
但如果写成</p>
<pre><code>while let message = receiver.lock().unwrap().recv().unwrap() {
};
</code></pre>
<p>while let 后面整个括号都是一个作用域,要在这个作用域结束后,锁才会释放,比上面let message要锁定久时间。<br>
rust的mutex锁没有对应的unlock方法,由mutex的生命周期管理。</p>
<p>我们给Pool实现<code>Drop</code> trait, 让Pool被销毁时,自动暂停掉worker线程的执行。</p>
<pre><code>impl Drop for Pool {
    fn drop(&amp;mut self) {
      for _ in 0..self.max_workers {
            self.sender.send(Message::ByeBye).unwrap();
      }
      for w in self.workers.iter_mut() {
            if let Some(t) = w.t.take() {
                t.join().unwrap();
            }
      }
    }
}

</code></pre>
<p><strong>drop方法里面用了两个循环</strong>,而不是在一个循环里做完两件事?</p>
<pre><code>for w in self.workers.iter_mut() {
    if let Some(t) = w.t.take() {
      self.sender.send(Message::ByeBye).unwrap();
      t.join().unwrap();
    }
}

</code></pre>
<p>这里面隐藏了一个会造成死锁的陷阱,比如两个Worker, 在单个循环里面迭代所有Worker,再将终止信息发送给通道后,直接调用join,<br>
我们预期是第一个worker要收到消息,并且等他执行完。当情况可能是第二个worker获取到了消息,第一个worker没有获取到,那接下来的join就会阻塞造成死锁。</p>
<p><strong>注意到没有,Worker是被包装在Option内的</strong>,这里有两个点需要注意</p>
<ol>
<li>t.join 需要持有t的所有权</li>
<li>在我们这种情况下,self.workers只能作为引用被for循环迭代。</li>
</ol>
<p>这里考虑让Worker持有<code>Option&lt;JoinHandle&lt;()&gt;&gt;</code>,后续可以通过在Option上调用take方法将Some变体的值移出来,并在原来的位置留下None变体。<br>
换而言之,让运行中的worker持有Some的变体,清理worker时,可以使用None替换掉Some,从而让Worker失去可以运行的线程</p>
<pre><code>struct Worker where
{
    _id: usize,
    t: Option&lt;JoinHandle&lt;()&gt;&gt;,
}
</code></pre>
<h1 id="要点总结">要点总结</h1>
<ul>
<li>Mutex依赖于生命周期管理锁的释放,使用的时候需要注意是否逾期持有锁</li>
<li><code>Vec&lt;Option&lt;T&gt;&gt;</code> 可以解决某些情况下需要T所有权的场景</li>
</ul>
<h1 id="完整代码">完整代码</h1>
<pre><code>use std::thread::{self, JoinHandle};
use std::sync::{Arc, mpsc, Mutex};


type Job = Box&lt;dyn FnOnce() + 'static + Send&gt;;
enum Message {
    ByeBye,
    NewJob(Job),
}

struct Worker where
{
    _id: usize,
    t: Option&lt;JoinHandle&lt;()&gt;&gt;,
}

impl Worker
{
    fn new(id: usize, receiver: Arc::&lt;Mutex&lt;mpsc::Receiver&lt;Message&gt;&gt;&gt;) -&gt; Worker {
      let t = thread::spawn( move || {
            loop {
                let message = receiver.lock().unwrap().recv().unwrap();
                match message {
                  Message::NewJob(job) =&gt; {
                        println!("do job from worker[{}]", id);
                        job();
                  },
                  Message::ByeBye =&gt; {
                        println!("ByeBye from worker[{}]", id);
                        break
                  },
                }
            }
      });

      Worker {
            _id: id,
            t: Some(t),
      }
    }
}

pub struct Pool {
    workers: Vec&lt;Worker&gt;,
    max_workers: usize,
    sender: mpsc::Sender&lt;Message&gt;
}

impl Pool where {
    pub fn new(max_workers: usize) -&gt; Pool {
      if max_workers == 0 {
            panic!("max_workers must be greater than zero!")
      }
      let (tx, rx) = mpsc::channel();

      let mut workers = Vec::with_capacity(max_workers);
      let receiver = Arc::new(Mutex::new(rx));
      for i in 0..max_workers {
            workers.push(Worker::new(i, Arc::clone(&amp;receiver)));
      }

      Pool { workers: workers, max_workers: max_workers, sender: tx }
    }
   
    pub fn execute&lt;F&gt;(&amp;self, f:F) where F: FnOnce() + 'static + Send
    {

      let job = Message::NewJob(Box::new(f));
      self.sender.send(job).unwrap();
    }
}

impl Drop for Pool {
    fn drop(&amp;mut self) {
      for _ in 0..self.max_workers {
            self.sender.send(Message::ByeBye).unwrap();
      }
      for w in self.workers {
            if let Some(t) = w.t.take() {
                t.join().unwrap();
            }
      }
    }
}


#
mod tests {
    use super::*;
    #
    fn it_works() {
      let p = Pool::new(4);
      p.execute(|| println!("do new job1"));
      p.execute(|| println!("do new job2"));
      p.execute(|| println!("do new job3"));
      p.execute(|| println!("do new job4"));
    }
}
</code></pre>

页: [1]
查看完整版本: rust 实战 - 实现一个线程工作池 ThreadPool