Moved python files from sw to sw/python; Moved scritps into sw/python/scripts
This commit is contained in:
parent
7c01f0a7e3
commit
3938c4aa31
@ -1,57 +0,0 @@
|
|||||||
import typing
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import seaborn as sns
|
|
||||||
import os
|
|
||||||
from utility import visualization
|
|
||||||
from utility.simulation import SimulationDeSerializer
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: This should be the responsibility of the DeSerializer
|
|
||||||
def get_sim_slugs(results_dir: str) -> typing.List[str]:
|
|
||||||
"""Get a list of slugified simulation names."""
|
|
||||||
result_files = [f for f in os.listdir(results_dir) if
|
|
||||||
os.path.isfile(os.path.join(results_dir, f))]
|
|
||||||
|
|
||||||
metadata_files = [f for f in result_files if f.endswith("_metadata.json")]
|
|
||||||
|
|
||||||
sim_slugs = [f.removesuffix("_metadata.json") for f in metadata_files]
|
|
||||||
|
|
||||||
return sim_slugs
|
|
||||||
|
|
||||||
|
|
||||||
def plot_results() -> None:
|
|
||||||
"""Plot the BER curves for all present simulation results."""
|
|
||||||
saves_dir = "sim_saves"
|
|
||||||
results_dir = "sim_results"
|
|
||||||
|
|
||||||
slugs = get_sim_slugs(results_dir)
|
|
||||||
|
|
||||||
deserializer = SimulationDeSerializer(save_dir=saves_dir,
|
|
||||||
results_dir=results_dir)
|
|
||||||
|
|
||||||
# Read data
|
|
||||||
|
|
||||||
data = []
|
|
||||||
for slug in slugs:
|
|
||||||
df, metadata = deserializer.read_results(slug)
|
|
||||||
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
|
|
||||||
|
|
||||||
graph_title = metadata["name"]
|
|
||||||
line_labels = metadata["labels"]
|
|
||||||
|
|
||||||
graph_tuple = (graph_title, df, line_labels)
|
|
||||||
data.append(graph_tuple)
|
|
||||||
|
|
||||||
# Plot results
|
|
||||||
|
|
||||||
sns.set_theme()
|
|
||||||
fig = visualization.plot_BERs(
|
|
||||||
title="Bit-Error-Rates of proximal decoder for different codes",
|
|
||||||
data=data, num_cols=4)
|
|
||||||
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
plot_results()
|
|
||||||
0
sw/python/cpp_modules/__init__.py
Normal file
0
sw/python/cpp_modules/__init__.py
Normal file
BIN
sw/python/cpp_modules/cpp_decoders.cpython-310-x86_64-linux-gnu.so
Executable file
BIN
sw/python/cpp_modules/cpp_decoders.cpython-310-x86_64-linux-gnu.so
Executable file
Binary file not shown.
@ -1,3 +1,9 @@
|
|||||||
|
import sys, os
|
||||||
|
import sys, os
|
||||||
|
sys.path.append(os.path.abspath('../..'))
|
||||||
|
print(sys.path)
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -63,7 +69,7 @@ def get_params(code_name: str):
|
|||||||
"""In this function all parameters for the simulation are defined."""
|
"""In this function all parameters for the simulation are defined."""
|
||||||
# Define global simulation parameters
|
# Define global simulation parameters
|
||||||
|
|
||||||
H_file = f"res/{code_name}.alist"
|
H_file = f"../../res/{code_name}.alist"
|
||||||
|
|
||||||
H = codes.read_alist_file(H_file)
|
H = codes.read_alist_file(H_file)
|
||||||
n_min_k, n = H.shape
|
n_min_k, n = H.shape
|
||||||
@ -1,3 +1,6 @@
|
|||||||
|
import sys, os
|
||||||
|
sys.path.append(os.path.abspath('../..'))
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -47,7 +50,7 @@ def get_params(code_name: str):
|
|||||||
"""In this function all parameters for the simulation are defined."""
|
"""In this function all parameters for the simulation are defined."""
|
||||||
# Define global simulation parameters
|
# Define global simulation parameters
|
||||||
|
|
||||||
H_file = f"res/{code_name}.alist"
|
H_file = f"../../res/{code_name}.alist"
|
||||||
H = codes.read_alist_file(H_file)
|
H = codes.read_alist_file(H_file)
|
||||||
n_min_k, n = H.shape
|
n_min_k, n = H.shape
|
||||||
k = n - n_min_k
|
k = n - n_min_k
|
||||||
@ -1,3 +1,6 @@
|
|||||||
|
import sys, os
|
||||||
|
sys.path.append(os.path.abspath('../..'))
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
@ -14,7 +17,7 @@ from cpp_modules.cpp_decoders import ProximalDecoder_204_102 as ProximalDecoder
|
|||||||
|
|
||||||
|
|
||||||
def simulate(H_file, SNR, omega, K, gamma):
|
def simulate(H_file, SNR, omega, K, gamma):
|
||||||
H = codes.read_alist_file(f"res/{H_file}")
|
H = codes.read_alist_file(f"../../res/{H_file}")
|
||||||
n_min_k, n = H.shape
|
n_min_k, n = H.shape
|
||||||
k = n - n_min_k
|
k = n - n_min_k
|
||||||
|
|
||||||
121
sw/python/utility/simulation/simulators.py
Normal file
121
sw/python/utility/simulation/simulators.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import typing
|
||||||
|
from tqdm import tqdm
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, process, wait
|
||||||
|
from functools import partial
|
||||||
|
from multiprocessing import Lock
|
||||||
|
|
||||||
|
from utility import noise
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Fix ProximalDecoder_Dynamic
|
||||||
|
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as
|
||||||
|
# ProximalDecoder
|
||||||
|
|
||||||
|
|
||||||
|
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
||||||
|
"""Count the number of wrong bits in a decoded codeword.
|
||||||
|
|
||||||
|
:param d: Originally sent data
|
||||||
|
:param d_hat: Received data
|
||||||
|
:return: Number of bit errors
|
||||||
|
"""
|
||||||
|
return np.sum(d != d_hat)
|
||||||
|
|
||||||
|
|
||||||
|
class HashableDict:
|
||||||
|
"""Class behaving like an immutable dict. More importantly it is
|
||||||
|
hashable and thus usable as a key type for another dict."""
|
||||||
|
|
||||||
|
def __init__(self, data_dict):
|
||||||
|
assert (isinstance(data_dict, dict))
|
||||||
|
for key, val in data_dict.items():
|
||||||
|
self.__dict__[key] = val
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.__dict__[item]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return str(self.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
class GenericMultithreadedSimulator:
|
||||||
|
def __init__(self, max_workers=8):
|
||||||
|
self._format_func = None
|
||||||
|
self._task_func = None
|
||||||
|
self._task_params = None
|
||||||
|
self._max_workers = max_workers
|
||||||
|
|
||||||
|
self._results = {}
|
||||||
|
self._executor = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_params(self):
|
||||||
|
return self._task_params
|
||||||
|
|
||||||
|
@task_params.setter
|
||||||
|
def task_params(self, sim_params):
|
||||||
|
self._task_params = {HashableDict(iteration_params): iteration_params
|
||||||
|
for iteration_params in sim_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_func(self):
|
||||||
|
return self._task_func
|
||||||
|
|
||||||
|
@task_func.setter
|
||||||
|
def task_func(self, func):
|
||||||
|
self._task_func = func
|
||||||
|
|
||||||
|
@property
|
||||||
|
def format_func(self):
|
||||||
|
return self._format_func
|
||||||
|
|
||||||
|
@format_func.setter
|
||||||
|
def format_func(self, func):
|
||||||
|
self._format_func = func
|
||||||
|
|
||||||
|
def start_or_continue(self):
|
||||||
|
assert self._task_func is not None
|
||||||
|
assert self._task_params is not None
|
||||||
|
assert self._format_func is not None
|
||||||
|
|
||||||
|
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
|
||||||
|
|
||||||
|
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
|
||||||
|
def done_callback(key, f):
|
||||||
|
try:
|
||||||
|
pbar.update(1)
|
||||||
|
self._results[key] = f.result()
|
||||||
|
del self._task_params[key]
|
||||||
|
except process.BrokenProcessPool:
|
||||||
|
# This exception is thrown when the program is
|
||||||
|
# prematurely stopped with a KeyboardInterrupt
|
||||||
|
pass
|
||||||
|
|
||||||
|
futures = []
|
||||||
|
|
||||||
|
for key, params in list(self._task_params.items()):
|
||||||
|
future = self._executor.submit(self._task_func, params)
|
||||||
|
future.add_done_callback(partial(done_callback, key))
|
||||||
|
futures.append(future)
|
||||||
|
|
||||||
|
self._executor.shutdown(wait=True, cancel_futures=False)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
assert self._executor is not None, "The simulation has to be started" \
|
||||||
|
" before it can be stopped"
|
||||||
|
self._executor.shutdown(wait=True, cancel_futures=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_results(self):
|
||||||
|
return self._format_func(self._results)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state["_executor"] = None
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self._executor = ProcessPoolExecutor()
|
||||||
@ -1,361 +0,0 @@
|
|||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
import typing
|
|
||||||
from tqdm import tqdm
|
|
||||||
from concurrent.futures import ProcessPoolExecutor, process, wait
|
|
||||||
from functools import partial
|
|
||||||
from multiprocessing import Lock
|
|
||||||
|
|
||||||
from utility import noise
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix ProximalDecoder_Dynamic
|
|
||||||
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as
|
|
||||||
# ProximalDecoder
|
|
||||||
|
|
||||||
|
|
||||||
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
|
||||||
"""Count the number of wrong bits in a decoded codeword.
|
|
||||||
|
|
||||||
:param d: Originally sent data
|
|
||||||
:param d_hat: Received data
|
|
||||||
:return: Number of bit errors
|
|
||||||
"""
|
|
||||||
return np.sum(d != d_hat)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: Write unit tests
|
|
||||||
class ProximalDecoderSimulator:
|
|
||||||
"""Class allowing for saving of simulations state.
|
|
||||||
|
|
||||||
Given a list of decoders, this class allows for simulating the
|
|
||||||
Bit-Error-Rates of each decoder for various SNRs.
|
|
||||||
|
|
||||||
The functionality implemented by this class could be achieved by a bunch
|
|
||||||
of loops and a function. However, storing the state of the simulation as
|
|
||||||
member variables allows for pausing and resuming the simulation at a
|
|
||||||
later time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, n: int, k: int,
|
|
||||||
decoders: typing.Sequence[typing.Any],
|
|
||||||
SNRs: typing.Sequence[float],
|
|
||||||
target_frame_errors: int,
|
|
||||||
max_num_iterations: int):
|
|
||||||
"""Construct and object of type simulator.
|
|
||||||
|
|
||||||
:param n: Number of bits in a codeword
|
|
||||||
:param k: Number of bits in a dataword
|
|
||||||
:param decoders: Sequence of decoders to test
|
|
||||||
:param SNRs: Sequence of SNRs for which the BERs should be calculated
|
|
||||||
:param target_frame_errors: Number of frame errors after which to
|
|
||||||
stop the simulation
|
|
||||||
"""
|
|
||||||
# Simulation parameters
|
|
||||||
|
|
||||||
self._n = n
|
|
||||||
self._k = k
|
|
||||||
self._decoders = decoders
|
|
||||||
self._SNRs = SNRs
|
|
||||||
self._target_frame_errors = target_frame_errors
|
|
||||||
self._max_num_iterations = max_num_iterations
|
|
||||||
|
|
||||||
self._x = np.zeros(self._n)
|
|
||||||
self._x_bpsk = 1 - 2 * self._x # Map x from [0, 1]^n to [-1, 1]^n
|
|
||||||
|
|
||||||
# Simulation state
|
|
||||||
|
|
||||||
self._curr_decoder_index = 0
|
|
||||||
self._curr_SNRs_index = 0
|
|
||||||
|
|
||||||
self._curr_num_frame_errors = 0
|
|
||||||
self._curr_num_bit_errors = 0
|
|
||||||
self._curr_num_iterations = 0
|
|
||||||
self._curr_num_dec_fails = 0
|
|
||||||
|
|
||||||
# Results & Miscellaneous
|
|
||||||
|
|
||||||
self._BERs = [np.zeros(len(SNRs)) for i in range(len(decoders))]
|
|
||||||
self._dec_fails = [np.zeros(len(SNRs)) for i in range(len(decoders))]
|
|
||||||
self._avg_K = [np.zeros(len(SNRs)) for i in range(len(decoders))]
|
|
||||||
|
|
||||||
self._create_pbars()
|
|
||||||
|
|
||||||
self._sim_running = False
|
|
||||||
|
|
||||||
def _create_pbars(self):
|
|
||||||
self._overall_pbar = tqdm(total=len(self._decoders),
|
|
||||||
desc="Calculating the answer to life, "
|
|
||||||
"the universe and everything",
|
|
||||||
leave=False,
|
|
||||||
bar_format="{l_bar}{bar}| {n_fmt}/{"
|
|
||||||
"total_fmt} [{elapsed}]")
|
|
||||||
|
|
||||||
decoder = self._decoders[self._curr_decoder_index]
|
|
||||||
self._decoder_pbar = tqdm(total=len(self._SNRs),
|
|
||||||
desc=f"Calculating"
|
|
||||||
f"g BERs"
|
|
||||||
f" for {decoder.__class__.__name__}",
|
|
||||||
leave=False,
|
|
||||||
bar_format="{l_bar}{bar}| {n_fmt}/{"
|
|
||||||
"total_fmt}")
|
|
||||||
|
|
||||||
self._snr_pbar = tqdm(total=self._max_num_iterations,
|
|
||||||
desc=f"Simulating for SNR = {self._SNRs[0]} dB",
|
|
||||||
leave=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getstate__(self) -> typing.Dict:
|
|
||||||
"""Custom serialization function called by the 'pickle' module
|
|
||||||
when saving the state of a currently running simulation
|
|
||||||
"""
|
|
||||||
state = self.__dict__.copy()
|
|
||||||
del state['_overall_pbar']
|
|
||||||
del state['_decoder_pbar']
|
|
||||||
del state['_snr_pbar']
|
|
||||||
return state
|
|
||||||
|
|
||||||
def __setstate__(self, state) -> None:
|
|
||||||
"""Custom deserialization function called by the 'pickle' module
|
|
||||||
when loading a previously saved simulation
|
|
||||||
|
|
||||||
:param state: Dictionary storing the serialized version of an object
|
|
||||||
of this class
|
|
||||||
"""
|
|
||||||
self.__dict__.update(state)
|
|
||||||
|
|
||||||
self._create_pbars()
|
|
||||||
|
|
||||||
self._overall_pbar.update(self._curr_decoder_index)
|
|
||||||
self._decoder_pbar.update(self._curr_SNRs_index)
|
|
||||||
self._snr_pbar.update(self._curr_num_frame_errors)
|
|
||||||
|
|
||||||
self._overall_pbar.refresh()
|
|
||||||
self._decoder_pbar.refresh()
|
|
||||||
self._snr_pbar.refresh()
|
|
||||||
|
|
||||||
def _simulate_transmission(self) -> int:
|
|
||||||
"""Simulate the transmission of a single codeword.
|
|
||||||
|
|
||||||
:return: Number of bit errors that occurred
|
|
||||||
"""
|
|
||||||
SNR = self._SNRs[self._curr_SNRs_index]
|
|
||||||
decoder = self._decoders[self._curr_decoder_index]
|
|
||||||
|
|
||||||
y = noise.add_awgn(self._x_bpsk, SNR, self._n, self._k)
|
|
||||||
x_hat, K = decoder.decode(y)
|
|
||||||
|
|
||||||
# Handle decoding failure
|
|
||||||
if x_hat is not None:
|
|
||||||
self._avg_K[self._curr_decoder_index][self._curr_SNRs_index] += K
|
|
||||||
return count_bit_errors(self._x, x_hat)
|
|
||||||
else:
|
|
||||||
self._curr_num_dec_fails += 1
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def _update_statistics(self, bit_errors: int) -> None:
|
|
||||||
"""Update the statistics of the simulator.
|
|
||||||
|
|
||||||
:param bit_errors: Number of bit errors that occurred during the
|
|
||||||
last transmission
|
|
||||||
"""
|
|
||||||
self._curr_num_iterations += 1
|
|
||||||
self._snr_pbar.update(1)
|
|
||||||
|
|
||||||
if bit_errors > 0:
|
|
||||||
self._curr_num_frame_errors += 1
|
|
||||||
self._curr_num_bit_errors += bit_errors
|
|
||||||
|
|
||||||
def _advance_state(self) -> None:
|
|
||||||
"""Advance the state of the simulator.
|
|
||||||
|
|
||||||
This function also handles setting the result arrays and progress bars.
|
|
||||||
"""
|
|
||||||
if (self._curr_num_frame_errors >= self._target_frame_errors) or (
|
|
||||||
self._curr_num_iterations > self._max_num_iterations):
|
|
||||||
|
|
||||||
# Adjust the number of iterations to ignore decoding failures
|
|
||||||
adj_num_iterations = self._curr_num_iterations - \
|
|
||||||
self._curr_num_dec_fails
|
|
||||||
|
|
||||||
if adj_num_iterations == 0:
|
|
||||||
self._BERs[self._curr_decoder_index][self._curr_SNRs_index] = 1
|
|
||||||
else:
|
|
||||||
self._BERs[self._curr_decoder_index][self._curr_SNRs_index] \
|
|
||||||
= self._curr_num_bit_errors / (
|
|
||||||
adj_num_iterations * self._n)
|
|
||||||
self._avg_K[self._curr_decoder_index][self._curr_SNRs_index] \
|
|
||||||
= \
|
|
||||||
self._avg_K[self._curr_decoder_index][
|
|
||||||
self._curr_SNRs_index] / adj_num_iterations
|
|
||||||
|
|
||||||
self._dec_fails[self._curr_decoder_index][self._curr_SNRs_index] \
|
|
||||||
= self._curr_num_dec_fails
|
|
||||||
|
|
||||||
self._curr_num_frame_errors = 0
|
|
||||||
self._curr_num_bit_errors = 0
|
|
||||||
self._curr_num_iterations = 0
|
|
||||||
self._curr_num_dec_fails = 0
|
|
||||||
|
|
||||||
if self._curr_SNRs_index < len(self._SNRs) - 1:
|
|
||||||
self._curr_SNRs_index += 1
|
|
||||||
|
|
||||||
self._snr_pbar.reset()
|
|
||||||
self._overall_pbar.refresh()
|
|
||||||
self._snr_pbar.set_description(
|
|
||||||
f"Simulating for SNR = "
|
|
||||||
f"{self._SNRs[self._curr_SNRs_index]} dB")
|
|
||||||
self._decoder_pbar.update(1)
|
|
||||||
else:
|
|
||||||
if self._curr_decoder_index < len(self._decoders) - 1:
|
|
||||||
self._curr_decoder_index += 1
|
|
||||||
self._curr_SNRs_index = 0
|
|
||||||
|
|
||||||
self._decoder_pbar.reset()
|
|
||||||
decoder = self._decoders[self._curr_decoder_index]
|
|
||||||
self._decoder_pbar.set_description(
|
|
||||||
f"Calculating BERs for {decoder.__class__.__name__}")
|
|
||||||
self._overall_pbar.update(1)
|
|
||||||
else:
|
|
||||||
self._sim_running = False
|
|
||||||
|
|
||||||
self._snr_pbar.close()
|
|
||||||
self._decoder_pbar.close()
|
|
||||||
self._overall_pbar.close()
|
|
||||||
|
|
||||||
def start_or_continue(self) -> None:
|
|
||||||
"""Start the simulation.
|
|
||||||
|
|
||||||
This is a blocking call. A call to the stop() function
|
|
||||||
from another thread will stop this function.
|
|
||||||
"""
|
|
||||||
self._sim_running = True
|
|
||||||
|
|
||||||
while self._sim_running:
|
|
||||||
bit_errors = self._simulate_transmission()
|
|
||||||
self._update_statistics(bit_errors)
|
|
||||||
self._advance_state()
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the simulation."""
|
|
||||||
self._sim_running = False
|
|
||||||
|
|
||||||
def get_current_results(self) -> pd.DataFrame:
|
|
||||||
"""Get the current results.
|
|
||||||
|
|
||||||
If the simulation has not yet completed, the BERs which have not yet
|
|
||||||
been calculated are set to 0.
|
|
||||||
|
|
||||||
:return: pandas Dataframe with the columns ["SNR", "BER_1", "BER_2",
|
|
||||||
..., "DecFails_1", "DecFails_2", ...]
|
|
||||||
"""
|
|
||||||
data = {"SNR": np.array(self._SNRs)}
|
|
||||||
|
|
||||||
for i, decoder_BERs in enumerate(self._BERs):
|
|
||||||
data[f"BER_{i}"] = decoder_BERs
|
|
||||||
|
|
||||||
for i, decoder_dec_fails in enumerate(self._dec_fails):
|
|
||||||
data[f"DecFails_{i}"] = decoder_dec_fails
|
|
||||||
|
|
||||||
for i, avg_K in enumerate(self._avg_K):
|
|
||||||
data[f"AvgK_{i}"] = avg_K
|
|
||||||
|
|
||||||
return pd.DataFrame(data)
|
|
||||||
|
|
||||||
|
|
||||||
class HashableDict:
|
|
||||||
"""Class behaving like an immutable dict. More importantly it is
|
|
||||||
hashable and thus usable as a key type for another dict."""
|
|
||||||
|
|
||||||
def __init__(self, data_dict):
|
|
||||||
assert (isinstance(data_dict, dict))
|
|
||||||
for key, val in data_dict.items():
|
|
||||||
self.__dict__[key] = val
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self.__dict__[item]
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return str(self.__dict__)
|
|
||||||
|
|
||||||
|
|
||||||
class GenericMultithreadedSimulator:
|
|
||||||
def __init__(self, max_workers=8):
|
|
||||||
self._format_func = None
|
|
||||||
self._task_func = None
|
|
||||||
self._task_params = None
|
|
||||||
self._max_workers = max_workers
|
|
||||||
|
|
||||||
self._results = {}
|
|
||||||
self._executor = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def task_params(self):
|
|
||||||
return self._task_params
|
|
||||||
|
|
||||||
@task_params.setter
|
|
||||||
def task_params(self, sim_params):
|
|
||||||
self._task_params = {HashableDict(iteration_params): iteration_params
|
|
||||||
for iteration_params in sim_params}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def task_func(self):
|
|
||||||
return self._task_func
|
|
||||||
|
|
||||||
@task_func.setter
|
|
||||||
def task_func(self, func):
|
|
||||||
self._task_func = func
|
|
||||||
|
|
||||||
@property
|
|
||||||
def format_func(self):
|
|
||||||
return self._format_func
|
|
||||||
|
|
||||||
@format_func.setter
|
|
||||||
def format_func(self, func):
|
|
||||||
self._format_func = func
|
|
||||||
|
|
||||||
def start_or_continue(self):
|
|
||||||
assert self._task_func is not None
|
|
||||||
assert self._task_params is not None
|
|
||||||
assert self._format_func is not None
|
|
||||||
|
|
||||||
self._executor = ProcessPoolExecutor(max_workers=self._max_workers)
|
|
||||||
|
|
||||||
with tqdm(total=(len(self._task_params)), leave=False) as pbar:
|
|
||||||
def done_callback(key, f):
|
|
||||||
try:
|
|
||||||
pbar.update(1)
|
|
||||||
self._results[key] = f.result()
|
|
||||||
del self._task_params[key]
|
|
||||||
except process.BrokenProcessPool:
|
|
||||||
# This exception is thrown when the program is
|
|
||||||
# prematurely stopped with a KeyboardInterrupt
|
|
||||||
# TODO: Make sure task_params have not been removed
|
|
||||||
pass
|
|
||||||
|
|
||||||
futures = []
|
|
||||||
|
|
||||||
for key, params in list(self._task_params.items()):
|
|
||||||
future = self._executor.submit(self._task_func, params)
|
|
||||||
future.add_done_callback(partial(done_callback, key))
|
|
||||||
futures.append(future)
|
|
||||||
|
|
||||||
self._executor.shutdown(wait=True, cancel_futures=False)
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
assert self._executor is not None, "The simulation has to be started" \
|
|
||||||
" before it can be stopped"
|
|
||||||
self._executor.shutdown(wait=True, cancel_futures=True)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_results(self):
|
|
||||||
return self._format_func(self._results)
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
state = self.__dict__.copy()
|
|
||||||
state["_executor"] = None
|
|
||||||
return state
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
self.__dict__.update(state)
|
|
||||||
self._executor = ProcessPoolExecutor()
|
|
||||||
Loading…
Reference in New Issue
Block a user