ba-thesis/sw/utility/simulation.py

343 lines
12 KiB
Python

"""This file contains utility functions relating to tests and simulations of the decoders."""
import time
import numpy as np
import typing
from tqdm import tqdm
from timeit import default_timer
import signal
from dataclasses import dataclass
import pickle
import os.path
from pathlib import Path
import spdlog as spd
from utility import noise
def count_bit_errors(d: np.ndarray, d_hat: np.ndarray) -> int:
    """Count the number of wrong bits in a decoded codeword.

    :param d: Originally sent data
    :param d_hat: Received (decoded) data, comparable element-wise with ``d``
    :return: Number of positions where ``d`` and ``d_hat`` differ, as a Python int
    """
    # np.asarray makes plain lists compare element-wise instead of as whole
    # objects; int() turns the NumPy scalar into the promised Python int.
    return int(np.count_nonzero(np.asarray(d) != np.asarray(d_hat)))
# def test_decoder(n: int,
# k: int,
# decoder: typing.Any,
# SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
# target_frame_errors: int = 100) \
# -> typing.Tuple[np.array, np.array]:
# """Calculate the Bit Error Rate (BER) for a given decoder for a number of SNRs.
#
# This function assumes the all-zeros assumption holds. Progress is printed to stdout.
#
# :param n: Length of a codeword of the used code
# :param k: Length of a dataword of the used code
# :param decoder: Instance of the decoder to be tested
# :param SNRs: List of SNRs for which the BER should be calculated
# :param target_frame_errors: Number of frame errors after which to stop the simulation
# :param N_max: Maximum number of iterations to perform for each SNR
# :return: Tuple of numpy arrays of the form (SNRs, BERs)
# """
#
# x = np.zeros(n)
# x_bpsk = 1 - 2 * x # Map x from [0, 1]^n to [-1, 1]^n
#
# BERs = []
# for SNR in tqdm(SNRs,
# desc=f"Calculating BERs for {decoder.__class__.__name__}",
# position=1,
# leave=False,
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
#
# total_bit_errors = 0
# total_bits = 0
# total_frame_errors = 0
#
# pbar = tqdm(total=target_frame_errors,
# desc=f"Simulating for SNR = {SNR} dB",
# position=2,
# leave=False,
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
#
# while total_frame_errors < target_frame_errors:
# # Simulate channel
# y = noise.add_awgn(x_bpsk, SNR, n, k)
#
# # Decode received frame
# x_hat = decoder.decode(y)
#
# # Calculate statistics
# bit_errors = count_bit_errors(x, x_hat)
# total_bits += x.size
#
# if bit_errors > 0:
# total_frame_errors += 1
# total_bit_errors += bit_errors
# pbar.update(1)
#
# pbar.close()
#
# BERs.append(total_bit_errors / total_bits)
#
# return np.array(SNRs), np.array(BERs)
#
#
# def test_decoders(n: int,
# k: int,
# decoders: typing.List,
# SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
# target_frame_errors: int = 100) \
# -> typing.Tuple[np.array, np.array]:
# """Calculate the Bit Error Rate (BER) for a number of given decoders for a number of SNRs.
#
# This function assumes the all-zeros assumption holds. Progress is printed to stdout.
#
# :param n: Length of a codeword of the used code
# :param k: Length of a dataword of the used code
# :param decoders: List of decoder objects to be tested
# :param SNRs: List of SNRs for which the BER should be calculated
# :param target_frame_errors: Number of frame errors after which to stop the simulation
# :return: Tuple of the form (SNRs, [BERs_1, BERs_2, ...]) where SNR and BERs_x are numpy arrays
# """
# result_BERs = []
#
# start_time = default_timer()
#
# for decoder in tqdm(decoders,
# desc="Calculating the answer to life, the universe and everything",
# position=0,
# leave=False,
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
# _, BERs = test_decoder(n, k, decoder, SNRs, target_frame_errors)
# result_BERs.append(BERs)
#
# end_time = default_timer()
# print(f"Elapsed time: {end_time - start_time:.2f}s")
#
# return SNRs, result_BERs
#
@dataclass()
class SimulationParameters:
    """Static description of a BER simulation run.

    These values do not change while the simulation runs; they describe the
    code, the decoders under test and the stopping criterion.
    """

    n: int  # Length of a codeword of the used code
    k: int  # Length of a dataword of the used code
    decoders: typing.Sequence[typing.Any]  # Decoder objects to be tested
    SNRs: typing.Sequence[float]  # SNR values (dB) to simulate
    target_frame_errors: int  # Frame errors to collect per SNR before moving on
@dataclass
class SimulationState:
    """Data structure storing the state of the simulation.

    Holds the error/bit counters for the SNR currently being simulated plus
    the indices of the decoder and SNR the simulation has reached.
    """
    num_frame_errors: int = 0  # Frame errors observed at the current SNR
    num_bit_errors: int = 0  # Bit errors observed at the current SNR
    num_total_bits: int = 0  # Bits simulated at the current SNR
    # simulation_time: float = 0
    current_SNRs_index: int = 0  # Index of the SNR currently being simulated
    # Index of the decoder currently being simulated. It needs a type
    # annotation to be a real dataclass field (constructor argument, repr, eq)
    # rather than a plain class attribute. It is declared last so the
    # positional order of the other constructor arguments is unchanged.
    current_decoder_index: int = 0
# TODO: Make more generic
class SimulationManager:
    """Run a BER simulation that can be interrupted and resumed later.

    On SIGINT/SIGTERM during a simulation, the simulation parameters and the
    running state are pickled to ``save_dir``, so that
    :meth:`continue_unfinished` can pick up where the run stopped. Once a
    simulation has run to completion, these save files are deleted.
    """

    def __init__(self, save_dir: str, results_dir: str):
        """Create the manager and install the SIGINT/SIGTERM handlers.

        :param save_dir: Directory for the pickled simulation parameters and
            state and for the log file; created if it does not exist
        :param results_dir: Directory for results (stored, but not yet used
            by any method of this class)
        """
        self._save_dir = save_dir
        self._sim_parameters_filepath = f"{self._save_dir}/sim_parameters.pickle"
        self._sim_state_filepath = f"{self._save_dir}/sim_state.pickle"
        self._logs_filepath = f"{self._save_dir}/logs.txt"
        self._results_dir = results_dir

        # TODO: Should the be none or SimulationParameters() and SimulationState() respectively?
        self._sim_params = None
        self._sim_state = None
        self._sim_running = False

        Path(self._save_dir).mkdir(parents=True, exist_ok=True)

        self._logger = spd.FileLogger("SimulationManager", self._logs_filepath)
        self._logger.set_level(spd.LogLevel.DEBUG)

        signal.signal(signal.SIGINT, self._exit_gracefully)
        signal.signal(signal.SIGTERM, self._exit_gracefully)

    #
    # Functions relating to the pausing and restarting of simulations
    #

    def unfinished_simulation_present(self) -> bool:
        """Return whether both save files of an interrupted simulation exist."""
        return os.path.isfile(self._sim_parameters_filepath) \
            and os.path.isfile(self._sim_state_filepath)

    def continue_unfinished(self):
        """Load the saved parameters and state from disk and resume the simulation.

        The save files are unpickled, so they must come from a trusted source
        (normally an earlier run of this class). Never point ``save_dir`` at
        data you did not create yourself.
        """
        assert self.unfinished_simulation_present()

        with open(self._sim_parameters_filepath, "rb") as file:
            self._sim_params = pickle.load(file)
        with open(self._sim_state_filepath, "rb") as file:
            self._sim_state = pickle.load(file)

        self._logger.info("Loaded saved simulation state")

        self.start()

    @staticmethod
    def _write_pickle_atomically(obj: typing.Any, filepath: str):
        """Pickle ``obj`` to ``filepath``, replacing any previous file.

        The data is first written to a temporary file next to the target and
        then moved over it with ``os.replace``. An interruption during the
        write therefore leaves the previous file intact instead of a truncated
        pickle.
        """
        tmp_filepath = f"{filepath}.tmp"
        with open(tmp_filepath, "wb") as file:
            pickle.dump(obj, file)
        os.replace(tmp_filepath, filepath)

    def _save_state(self):
        """Write the current simulation parameters and state to ``save_dir``.

        Any previously saved files are overwritten.
        """
        self._write_pickle_atomically(self._sim_params, self._sim_parameters_filepath)
        self._write_pickle_atomically(self._sim_state, self._sim_state_filepath)
        self._logger.info("Saved simulation state")

    def _remove_saved_state(self):
        """Delete the save files, ignoring files that do not exist."""
        for filepath in (self._sim_parameters_filepath, self._sim_state_filepath):
            try:
                os.remove(filepath)
            except FileNotFoundError:
                pass

    def _exit_gracefully(self, *args):
        """Signal handler for SIGINT/SIGTERM.

        While a simulation is running, ask it to stop and take an early
        checkpoint of the state. ``start`` saves again at the next frame
        boundary, where the counters are guaranteed to be consistent.

        When no simulation is running there is nothing to checkpoint. The
        handler then defers to ``signal.default_int_handler``, which raises
        ``KeyboardInterrupt``. Otherwise the signal would be silently
        swallowed and the process could not be stopped this way.
        """
        self._logger.debug("Intercepted signal SIGINT/SIGTERM")

        if not self._sim_running:
            signal.default_int_handler(*args)

        self._sim_running = False
        if (self._sim_params is not None) and (self._sim_state is not None):
            self._save_state()

    #
    # Functions responsible for the actual simulation
    #

    def start(self):
        """Run (or resume) the BER simulation for the current decoder.

        The all-zeros codeword is mapped to BPSK and passed through the
        channel. Starting at ``current_SNRs_index``, frames are simulated for
        each SNR until ``target_frame_errors`` frame errors are observed. Only
        the decoder at ``current_decoder_index`` is simulated; this method
        does not advance that index.

        If a stop is requested via SIGINT/SIGTERM, the state is saved before
        the next frame is simulated and the method returns. When all SNRs
        have been simulated, the save files are removed.
        """
        self._sim_running = True  # TODO: Move this somewhere else

        params = self._sim_params
        state = self._sim_state

        decoder = params.decoders[state.current_decoder_index]

        x = np.zeros(params.n)
        x_bpsk = 1 - 2 * x  # Map x from [0, 1]^n to [-1, 1]^n

        BERs = []

        for SNR in tqdm(params.SNRs[state.current_SNRs_index:],
                        desc=f"Calculating BERs for {decoder.__class__.__name__}",
                        position=1,
                        leave=False,
                        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
            pbar = tqdm(total=params.target_frame_errors,
                        desc=f"Simulating for SNR = {SNR} dB",
                        position=2,
                        leave=False,
                        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
            # Account for frame errors already collected before a resume
            pbar.update(state.num_frame_errors)

            while state.num_frame_errors < params.target_frame_errors:
                if not self._sim_running:
                    # Checkpoint between frames: the counters are consistent
                    # here, unlike at an arbitrary point inside the handler.
                    pbar.close()
                    self._save_state()
                    return

                # Simulate channel
                y = noise.add_awgn(x_bpsk, SNR, params.n, params.k)

                # Decode received frame
                x_hat = decoder.decode(y)

                # Calculate statistics
                bit_errors = count_bit_errors(x, x_hat)
                state.num_total_bits += x.size

                if bit_errors > 0:
                    state.num_frame_errors += 1
                    state.num_bit_errors += bit_errors
                    pbar.update(1)

            # TODO: Load BERs from file as well
            # With target_frame_errors <= 0 no frame is simulated, so the BER
            # is undefined rather than a division by zero.
            if state.num_total_bits > 0:
                BERs.append(state.num_bit_errors / state.num_total_bits)
            else:
                BERs.append(np.nan)

            pbar.close()

            # Reset the counters BEFORE advancing the index. A checkpoint taken
            # in between then repeats this SNR rather than carrying its
            # finished counters over into the next SNR, which would skip it.
            state.num_frame_errors = 0
            state.num_bit_errors = 0
            state.num_total_bits = 0
            state.current_SNRs_index += 1

        # The simulation completed: remove the save files so this run is not
        # later reported by unfinished_simulation_present().
        self._sim_running = False
        self._remove_saved_state()
        self._logger.info("Simulation finished, removed saved simulation state")

        # return np.array(self._sim_params.SNRs), np.array(BERs)
class DecoderTester:
    """Class used to test decoders simulating BPSK modulation and an AWGN channel.

    Allows for recovering a stopped simulation if its previous state is known.
    """

    def __init__(self, initial_sim_state: typing.Optional[SimulationState] = None):
        """Construct a DecoderTester object.

        :param initial_sim_state: State the simulation should start from. If
            omitted, a fresh ``SimulationState()`` is created for this tester.
        """
        # A ``SimulationState()`` default argument would be evaluated once, at
        # definition time, and then shared (and mutated) by every tester
        # constructed without an explicit state.
        self._state = initial_sim_state if initial_sim_state is not None else SimulationState()

    def get_state(self) -> SimulationState:
        """Return the simulation state object held by this tester."""
        return self._state

    def configure(self, n: int,
                  k: int,
                  SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
                  target_frame_errors: int = 100):
        """Configure the simulation parameters.

        Not implemented yet: currently does nothing.

        :param n: Length of a codeword of the used code
        :param k: Length of a dataword of the used code
        :param SNRs: SNR values for which the BER should be calculated
        :param target_frame_errors: Number of frame errors after which to stop
            the simulation for one SNR
        """
        pass

    def start_test(self, decoders: typing.List):
        """Start testing the given decoders.

        Not implemented yet: currently does nothing.

        :param decoders: Decoder objects to be tested
        """
        pass

    def stop(self):
        """Stop a running test.

        Not implemented yet: currently does nothing.
        """
        pass