diff --git a/sw/utility/simulation/simulators.py b/sw/utility/simulation/simulators.py index a8f69a6..35443d3 100644 --- a/sw/utility/simulation/simulators.py +++ b/sw/utility/simulation/simulators.py @@ -2,8 +2,12 @@ import pandas as pd import numpy as np import typing from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor, process, wait +from functools import partial +from multiprocessing import Lock from utility import noise +from cpp_modules.cpp_decoders import ProximalDecoder def count_bit_errors(d: np.array, d_hat: np.array) -> int: @@ -176,10 +180,10 @@ class ProximalDecoderSimulator: self._BERs[self._curr_decoder_index][self._curr_SNRs_index] \ = self._curr_num_bit_errors / ( adj_num_iterations * self._n) - self._avg_K[self._curr_decoder_index][self._curr_SNRs_index]\ + self._avg_K[self._curr_decoder_index][self._curr_SNRs_index] \ = \ - self._avg_K[self._curr_decoder_index][ - self._curr_SNRs_index] / adj_num_iterations + self._avg_K[self._curr_decoder_index][ + self._curr_SNRs_index] / adj_num_iterations self._dec_fails[self._curr_decoder_index][self._curr_SNRs_index] \ = self._curr_num_dec_fails @@ -215,7 +219,7 @@ class ProximalDecoderSimulator: self._decoder_pbar.close() self._overall_pbar.close() - def start(self) -> None: + def start_or_continue(self) -> None: """Start the simulation. This is a blocking call. A call to the stop() function @@ -252,4 +256,72 @@ class ProximalDecoderSimulator: for i, avg_K in enumerate(self._avg_K): data[f"AvgK_{i}"] = avg_K - return pd.DataFrame(data) \ No newline at end of file + return pd.DataFrame(data) + + +class GenericMultithreadedSimulator: + def __init__(self, max_workers=8): + self._task_func = None + self._task_params = None + self._max_workers = max_workers + + self._results = {} + self._executor = None + + @property + def task_params(self): + return self._task_params + + @task_params.setter + def task_params(self, params): + assert isinstance(params, dict) + self._task_params = params + + @property + def task_func(self): + return self._task_func + + @task_func.setter + def task_func(self, func): + self._task_func = func + + def start_or_continue(self): + self._executor = ProcessPoolExecutor(max_workers=self._max_workers) + + with tqdm(total=(len(self._task_params)), leave=False) as pbar: + def done_callback(key, f): + try: + pbar.update(1) + self._results[key] = f.result() + del self._task_params[key] + except process.BrokenProcessPool: + # This exception is thrown when the program is + # prematurely stopped with a KeyboardInterrupt + # TODO: Make sure task_params have not been removed + pass + + futures = [] + + for key, params in self._task_params.items(): + future = self._executor.submit(self._task_func, params) + future.add_done_callback(partial(done_callback, key)) + futures.append(future) + + self._executor.shutdown(wait=True, cancel_futures=False) + + def stop(self): + assert self._executor is not None, "The simulation has to be started" \ + " before it can be stopped" + self._executor.shutdown(wait=False, cancel_futures=True) + + def get_current_results(self): + return self._results + + def __getstate__(self): + state = self.__dict__.copy() + state["_executor"] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._executor = ProcessPoolExecutor()