cranelift_filetests/
concurrent.rs

1//! Run tests concurrently.
2//!
3//! This module provides the `ConcurrentRunner` struct which uses a pool of threads to run tests
4//! concurrently.
5
6use crate::runone;
7use cranelift_codegen::dbg::LOG_FILENAME_PREFIX;
8use cranelift_codegen::timing;
9use log::error;
10use std::panic::catch_unwind;
11use std::path::{Path, PathBuf};
12use std::sync::mpsc::{channel, Receiver, Sender};
13use std::sync::{Arc, Mutex};
14use std::thread;
15use std::time::Duration;
16
/// Request sent to worker threads: the caller-assigned job id and the path of
/// the test file to run.
struct Request(usize, PathBuf);
19
/// Reply sent from a worker (or heartbeat) thread back to the `ConcurrentRunner`.
pub enum Reply {
    /// A worker has picked up job `jobid` and is starting to run it.
    Starting {
        jobid: usize,
    },
    /// Job `jobid` finished; on success, `result` holds the test's runtime.
    Done {
        jobid: usize,
        result: anyhow::Result<Duration>,
    },
    /// Periodic heartbeat (roughly once per second) so a receiver blocking on
    /// the reply channel can implement timeouts.
    Tick,
}
31
/// Manage threads that run test jobs concurrently.
pub struct ConcurrentRunner {
    /// Channel for sending requests to the worker threads.
    /// The workers are sharing the receiver with an `Arc<Mutex<Receiver>>`.
    /// This is `None` when shutting down.
    request_tx: Option<Sender<Request>>,

    /// Channel for receiving replies from the workers.
    /// Workers have their own `Sender`.
    reply_rx: Receiver<Reply>,

    /// Join handles for the worker threads; each thread returns its
    /// accumulated pass timings when joined.
    handles: Vec<thread::JoinHandle<timing::PassTimes>>,
}
45
46impl ConcurrentRunner {
47    /// Create a new `ConcurrentRunner` with threads spun up.
48    pub fn new() -> Self {
49        let (request_tx, request_rx) = channel();
50        let request_mutex = Arc::new(Mutex::new(request_rx));
51        let (reply_tx, reply_rx) = channel();
52
53        heartbeat_thread(reply_tx.clone());
54
55        let num_threads = std::env::var("CRANELIFT_FILETESTS_THREADS")
56            .ok()
57            .map(|s| {
58                use std::str::FromStr;
59                let n = usize::from_str(&s).unwrap();
60                assert!(n > 0);
61                n
62            })
63            .unwrap_or_else(|| num_cpus::get());
64        let handles = (0..num_threads)
65            .map(|num| worker_thread(num, request_mutex.clone(), reply_tx.clone()))
66            .collect();
67
68        Self {
69            request_tx: Some(request_tx),
70            reply_rx,
71            handles,
72        }
73    }
74
75    /// Shut down worker threads orderly. They will finish any queued jobs first.
76    pub fn shutdown(&mut self) {
77        self.request_tx = None;
78    }
79
80    /// Join all the worker threads.
81    /// Transfer pass timings from the worker threads to the current thread.
82    pub fn join(&mut self) -> timing::PassTimes {
83        assert!(self.request_tx.is_none(), "must shutdown before join");
84        let mut pass_times = timing::PassTimes::default();
85        for h in self.handles.drain(..) {
86            match h.join() {
87                Ok(t) => pass_times.add(&t),
88                Err(e) => println!("worker panicked: {e:?}"),
89            }
90        }
91        pass_times
92    }
93
94    /// Add a new job to the queues.
95    pub fn put(&mut self, jobid: usize, path: &Path) {
96        self.request_tx
97            .as_ref()
98            .expect("cannot push after shutdown")
99            .send(Request(jobid, path.to_owned()))
100            .expect("all the worker threads are gone");
101    }
102
103    /// Get a job reply without blocking.
104    pub fn try_get(&mut self) -> Option<Reply> {
105        self.reply_rx.try_recv().ok()
106    }
107
108    /// Get a job reply, blocking until one is available.
109    pub fn get(&mut self) -> Option<Reply> {
110        self.reply_rx.recv().ok()
111    }
112}
113
114/// Spawn a heartbeat thread which sends ticks down the reply channel every second.
115/// This lets us implement timeouts without the not yet stable `recv_timeout`.
116fn heartbeat_thread(replies: Sender<Reply>) -> thread::JoinHandle<()> {
117    thread::Builder::new()
118        .name("heartbeat".to_string())
119        .spawn(move || {
120            file_per_thread_logger::initialize(LOG_FILENAME_PREFIX);
121            while replies.send(Reply::Tick).is_ok() {
122                thread::sleep(Duration::from_secs(1));
123            }
124        })
125        .unwrap()
126}
127
128/// Spawn a worker thread running tests.
129fn worker_thread(
130    thread_num: usize,
131    requests: Arc<Mutex<Receiver<Request>>>,
132    replies: Sender<Reply>,
133) -> thread::JoinHandle<timing::PassTimes> {
134    thread::Builder::new()
135        .name(format!("worker #{thread_num}"))
136        .spawn(move || {
137            file_per_thread_logger::initialize(LOG_FILENAME_PREFIX);
138            loop {
139                // Lock the mutex only long enough to extract a request.
140                let Request(jobid, path) = match requests.lock().unwrap().recv() {
141                    Err(..) => break, // TX end shut down. exit thread.
142                    Ok(req) => req,
143                };
144
145                // Tell them we're starting this job.
146                // The receiver should always be present for this as long as we have jobs.
147                replies.send(Reply::Starting { jobid }).unwrap();
148
149                let result = catch_unwind(|| runone::run(path.as_path(), None, None))
150                    .unwrap_or_else(|e| {
151                        // The test panicked, leaving us a `Box<Any>`.
152                        // Panics are usually strings.
153                        if let Some(msg) = e.downcast_ref::<String>() {
154                            anyhow::bail!("panicked in worker #{}: {}", thread_num, msg)
155                        } else if let Some(msg) = e.downcast_ref::<&'static str>() {
156                            anyhow::bail!("panicked in worker #{}: {}", thread_num, msg)
157                        } else {
158                            anyhow::bail!("panicked in worker #{}", thread_num)
159                        }
160                    });
161
162                if let Err(ref msg) = result {
163                    error!("FAIL: {}", msg);
164                }
165
166                replies.send(Reply::Done { jobid, result }).unwrap();
167            }
168
169            // Timing is accumulated independently per thread.
170            // Timings from this worker thread will be aggregated by `ConcurrentRunner::join()`.
171            timing::take_current()
172        })
173        .unwrap()
174}