ba-thesis/sw/utility/simulation/simulators.py

347 lines
12 KiB
Python

import pandas as pd
import numpy as np
import typing
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, process, wait
from functools import partial
from multiprocessing import Lock
from utility import noise
# TODO: Fix ProximalDecoder_Dynamic
# from cpp_modules.cpp_decoders import ProximalDecoder_Dynamic as ProximalDecoder
def count_bit_errors(d: np.ndarray, d_hat: np.ndarray) -> int:
    """Count the number of wrong bits in a decoded codeword.

    :param d: Originally sent data
    :param d_hat: Received (decoded) data, element-wise comparable with d
    :return: Number of positions in which d and d_hat differ, as a plain
        Python int (not a numpy scalar)
    """
    return int(np.sum(d != d_hat))
# TODO: Write unit tests
class ProximalDecoderSimulator:
    """Class allowing for saving of simulations state.

    Given a list of decoders, this class allows for simulating the
    Bit-Error-Rates of each decoder for various SNRs.

    The functionality implemented by this class could be achieved by a bunch
    of loops and a function. However, storing the state of the simulation as
    member variables allows for pausing and resuming the simulation at a
    later time (the progress bars are excluded from pickling and recreated
    on unpickling).
    """

    def __init__(self, n: int, k: int,
                 decoders: typing.Sequence[typing.Any],
                 SNRs: typing.Sequence[float],
                 target_frame_errors: int,
                 max_num_iterations: int):
        """Construct an object of type simulator.

        :param n: Number of bits in a codeword
        :param k: Number of bits in a dataword
        :param decoders: Sequence of decoders to test. Each decoder must have
            a ``decode(y)`` method returning a tuple ``(x_hat, K)``, where
            ``x_hat`` is ``None`` to signal a decoding failure
        :param SNRs: Sequence of SNRs (in dB) for which the BERs should be
            calculated
        :param target_frame_errors: Number of frame errors after which the
            simulation of a single (decoder, SNR) point is stopped
        :param max_num_iterations: Maximum number of transmissions simulated
            for a single (decoder, SNR) point
        """
        # Simulation parameters
        self._n = n
        self._k = k
        self._decoders = decoders
        self._SNRs = SNRs
        self._target_frame_errors = target_frame_errors
        self._max_num_iterations = max_num_iterations

        # The all-zeros codeword is always transmitted
        self._x = np.zeros(self._n)
        self._x_bpsk = 1 - 2 * self._x  # Map x from [0, 1]^n to [-1, 1]^n

        # Simulation state
        self._curr_decoder_index = 0
        self._curr_SNRs_index = 0
        self._curr_num_frame_errors = 0
        self._curr_num_bit_errors = 0
        self._curr_num_iterations = 0
        self._curr_num_dec_fails = 0

        # Results & Miscellaneous: one array (indexed by SNR) per decoder
        self._BERs = [np.zeros(len(SNRs)) for _ in range(len(decoders))]
        self._dec_fails = [np.zeros(len(SNRs)) for _ in range(len(decoders))]
        self._avg_K = [np.zeros(len(SNRs)) for _ in range(len(decoders))]

        self._create_pbars()

        self._sim_running = False

    def _create_pbars(self) -> None:
        """Create the overall, per-decoder and per-SNR progress bars."""
        self._overall_pbar = tqdm(total=len(self._decoders),
                                  desc="Calculating the answer to life, "
                                       "the universe and everything",
                                  leave=False,
                                  bar_format="{l_bar}{bar}| {n_fmt}/{"
                                             "total_fmt} [{elapsed}]")

        decoder = self._decoders[self._curr_decoder_index]

        self._decoder_pbar = tqdm(total=len(self._SNRs),
                                  desc=f"Calculating BERs"
                                       f" for {decoder.__class__.__name__}",
                                  leave=False,
                                  bar_format="{l_bar}{bar}| {n_fmt}/{"
                                             "total_fmt}")

        # Label with the *current* SNR so a resumed simulation is correct.
        self._snr_pbar = tqdm(total=self._max_num_iterations,
                              desc=f"Simulating for SNR = "
                                   f"{self._SNRs[self._curr_SNRs_index]} dB",
                              leave=False,
                              )

    def __getstate__(self) -> typing.Dict:
        """Custom serialization function called by the 'pickle' module
        when saving the state of a currently running simulation.

        The progress bars are dropped from the state; they are recreated
        by ``__setstate__``.
        """
        state = self.__dict__.copy()
        del state['_overall_pbar']
        del state['_decoder_pbar']
        del state['_snr_pbar']
        return state

    def __setstate__(self, state) -> None:
        """Custom deserialization function called by the 'pickle' module
        when loading a previously saved simulation.

        :param state: Dictionary storing the serialized version of an object
            of this class
        """
        self.__dict__.update(state)
        self._create_pbars()

        self._overall_pbar.update(self._curr_decoder_index)
        self._decoder_pbar.update(self._curr_SNRs_index)
        # The SNR bar counts transmissions (see _update_statistics), so it is
        # restored from the iteration counter, not the frame-error counter.
        self._snr_pbar.update(self._curr_num_iterations)

        self._overall_pbar.refresh()
        self._decoder_pbar.refresh()
        self._snr_pbar.refresh()

    def _simulate_transmission(self) -> int:
        """Simulate the transmission of a single codeword.

        A decoding failure (``x_hat is None``) is counted in
        ``_curr_num_dec_fails`` and reported as 0 bit errors. For successful
        decodes the returned ``K`` is accumulated into ``_avg_K`` (averaged
        later in ``_store_point_results``).

        :return: Number of bit errors that occurred
        """
        SNR = self._SNRs[self._curr_SNRs_index]
        decoder = self._decoders[self._curr_decoder_index]

        y = noise.add_awgn(self._x_bpsk, SNR, self._n, self._k)
        x_hat, K = decoder.decode(y)

        # Handle decoding failure
        if x_hat is not None:
            self._avg_K[self._curr_decoder_index][self._curr_SNRs_index] += K
            return count_bit_errors(self._x, x_hat)
        else:
            self._curr_num_dec_fails += 1
            return 0

    def _update_statistics(self, bit_errors: int) -> None:
        """Update the statistics of the simulator.

        :param bit_errors: Number of bit errors that occurred during the
            last transmission
        """
        self._curr_num_iterations += 1
        self._snr_pbar.update(1)

        if bit_errors > 0:
            self._curr_num_frame_errors += 1
            self._curr_num_bit_errors += bit_errors

    def _current_point_finished(self) -> bool:
        """Return whether the current (decoder, SNR) point is complete.

        A point is complete once the target number of frame errors has been
        reached or exactly ``max_num_iterations`` transmissions were run.
        """
        return (self._curr_num_frame_errors >= self._target_frame_errors
                or self._curr_num_iterations >= self._max_num_iterations)

    def _store_point_results(self) -> None:
        """Write BER, average K and decoding failures of the current point."""
        dec = self._curr_decoder_index
        snr = self._curr_SNRs_index

        # Adjust the number of iterations to ignore decoding failures
        adj_num_iterations = (self._curr_num_iterations
                              - self._curr_num_dec_fails)

        if adj_num_iterations == 0:
            # Every transmission failed to decode
            self._BERs[dec][snr] = 1
        else:
            self._BERs[dec][snr] = self._curr_num_bit_errors / (
                adj_num_iterations * self._n)
            # _avg_K holds the running sum of K up to this point
            self._avg_K[dec][snr] = self._avg_K[dec][snr] / adj_num_iterations

        self._dec_fails[dec][snr] = self._curr_num_dec_fails

    def _reset_point_counters(self) -> None:
        """Reset the per-point counters before simulating the next point."""
        self._curr_num_frame_errors = 0
        self._curr_num_bit_errors = 0
        self._curr_num_iterations = 0
        self._curr_num_dec_fails = 0

    def _move_to_next_point(self) -> None:
        """Advance to the next SNR / decoder, or finish the simulation."""
        if self._curr_SNRs_index < len(self._SNRs) - 1:
            self._curr_SNRs_index += 1
            self._snr_pbar.reset()
            self._overall_pbar.refresh()
            self._snr_pbar.set_description(
                f"Simulating for SNR = "
                f"{self._SNRs[self._curr_SNRs_index]} dB")
            self._decoder_pbar.update(1)
        elif self._curr_decoder_index < len(self._decoders) - 1:
            self._curr_decoder_index += 1
            self._curr_SNRs_index = 0
            self._decoder_pbar.reset()
            decoder = self._decoders[self._curr_decoder_index]
            self._decoder_pbar.set_description(
                f"Calculating BERs for {decoder.__class__.__name__}")
            # The SNR bar starts over at the first SNR for the new decoder
            self._snr_pbar.reset()
            self._snr_pbar.set_description(
                f"Simulating for SNR = "
                f"{self._SNRs[self._curr_SNRs_index]} dB")
            self._overall_pbar.update(1)
        else:
            self._sim_running = False

            self._snr_pbar.close()
            self._decoder_pbar.close()
            self._overall_pbar.close()

    def _advance_state(self) -> None:
        """Advance the state of the simulator.

        If the current (decoder, SNR) point is complete, store its results,
        reset the counters and move on to the next point. This also updates
        the progress bars.
        """
        if not self._current_point_finished():
            return

        self._store_point_results()
        self._reset_point_counters()
        self._move_to_next_point()

    def start_or_continue(self) -> None:
        """Start the simulation.

        This is a blocking call. A call to the stop() function
        from another thread will stop this function after the current
        transmission.
        """
        self._sim_running = True

        while self._sim_running:
            bit_errors = self._simulate_transmission()
            self._update_statistics(bit_errors)
            self._advance_state()

    def stop(self) -> None:
        """Stop the simulation."""
        self._sim_running = False

    def get_current_results(self) -> pd.DataFrame:
        """Get the current results.

        If the simulation has not yet completed, the BERs which have not yet
        been calculated are set to 0. The ``AvgK`` entry of the point that is
        currently being simulated still holds a running sum of K.

        :return: pandas DataFrame with the columns ["SNR", "BER_0", "BER_1",
            ..., "DecFails_0", "DecFails_1", ..., "AvgK_0", "AvgK_1", ...],
            where the suffix is the index of the decoder in ``decoders``
        """
        data = {"SNR": np.array(self._SNRs)}

        for i, decoder_BERs in enumerate(self._BERs):
            data[f"BER_{i}"] = decoder_BERs

        for i, decoder_dec_fails in enumerate(self._dec_fails):
            data[f"DecFails_{i}"] = decoder_dec_fails

        for i, avg_K in enumerate(self._avg_K):
            data[f"AvgK_{i}"] = avg_K

        return pd.DataFrame(data)
class HashableDict:
    """Dict-like wrapper that can be used as a key of another dict.

    The items of the wrapped dict are stored as instance attributes and can
    be read with ``obj[key]``. Hashing and equality are the default,
    identity-based ones, so two instances built from equal dicts are
    distinct keys, and the values themselves need not be hashable. The
    object is not meant to be mutated, but this is not enforced.
    """

    def __init__(self, data_dict):
        """:param data_dict: dict whose items should be wrapped
        :raises TypeError: if data_dict is not a dict
        """
        # An explicit check instead of ``assert`` so it survives ``python -O``
        if not isinstance(data_dict, dict):
            raise TypeError(f"HashableDict expects a dict, "
                            f"got {type(data_dict).__name__}")

        self.__dict__.update(data_dict)

    def __getitem__(self, item):
        return self.__dict__[item]

    def __str__(self):
        return str(self.__dict__)

    def __repr__(self):
        return f"{type(self).__name__}({self.__dict__!r})"
class GenericMultithreadedSimulator:
    """Run a task function once per parameter set in a process pool.

    Results are collected in ``current_results``, keyed by the
    ``HashableDict`` of the corresponding parameter set. Parameter sets
    whose task completed successfully are removed from ``task_params``, so
    calling ``start_or_continue()`` again (e.g. after unpickling) only runs
    the remaining tasks.
    """

    def __init__(self, max_workers=8):
        """:param max_workers: Maximum number of worker processes"""
        self._task_func = None
        self._task_params = None
        self._max_workers = max_workers
        self._results = {}
        self._executor = None

    @property
    def task_params(self):
        """Remaining tasks as a dict mapping ``HashableDict`` -> params."""
        return self._task_params

    @task_params.setter
    def task_params(self, sim_params):
        """:param sim_params: Iterable of dicts, one per task"""
        self._task_params = {HashableDict(iteration_params): iteration_params
                             for iteration_params in sim_params}

    @property
    def task_func(self):
        """Function called as ``task_func(params)`` for each parameter set."""
        return self._task_func

    @task_func.setter
    def task_func(self, func):
        self._task_func = func

    def start_or_continue(self):
        """Run all remaining tasks and block until they are done or stopped."""
        self._executor = ProcessPoolExecutor(max_workers=self._max_workers)

        with tqdm(total=(len(self._task_params)), leave=False) as pbar:
            def done_callback(key, f):
                # Futures cancelled by stop() never ran; keep their params so
                # they are executed when the simulation is resumed. Calling
                # f.result() on them would raise CancelledError.
                if f.cancelled():
                    return

                try:
                    pbar.update(1)
                    self._results[key] = f.result()
                    del self._task_params[key]
                except process.BrokenProcessPool:
                    # This exception is thrown when the program is
                    # prematurely stopped with a KeyboardInterrupt.
                    # The params are only deleted after a successful
                    # result, so the task stays pending.
                    pass

            # Iterate over a snapshot: callbacks delete entries concurrently
            for key, params in list(self._task_params.items()):
                future = self._executor.submit(self._task_func, params)
                future.add_done_callback(partial(done_callback, key))

            self._executor.shutdown(wait=True, cancel_futures=False)

    def stop(self):
        """Cancel pending tasks and stop the pool without waiting."""
        assert self._executor is not None, "The simulation has to be started" \
                                           " before it can be stopped"
        self._executor.shutdown(wait=False, cancel_futures=True)

    @property
    def current_results(self):
        """Dict mapping ``HashableDict`` of the params -> task result."""
        return self._results

    def __getstate__(self):
        # The executor cannot be pickled; it is recreated on demand
        state = self.__dict__.copy()
        state["_executor"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # start_or_continue() creates the executor (honoring max_workers);
        # creating one here would only leave an unused pool that is later
        # replaced without being shut down.
        self._executor = None