Renamed utility.simulations to utility.simulation; Implemented first version of the SimulationManager class
This commit is contained in:
parent
3e02dcf17c
commit
ffcce7b3f2
@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from utility import simulations, noise, codes
|
||||
from utility import simulation, noise, codes
|
||||
|
||||
|
||||
class CountBitErrorsTestCase(unittest.TestCase):
|
||||
@ -17,9 +17,9 @@ class CountBitErrorsTestCase(unittest.TestCase):
|
||||
d3 = np.array([0, 0, 0, 0])
|
||||
y_hat3 = np.array([1, 1, 1, 1])
|
||||
|
||||
self.assertEqual(simulations.count_bit_errors(d1, y_hat1), 2)
|
||||
self.assertEqual(simulations.count_bit_errors(d2, y_hat2), 0)
|
||||
self.assertEqual(simulations.count_bit_errors(d3, y_hat3), 4)
|
||||
self.assertEqual(simulation.count_bit_errors(d1, y_hat1), 2)
|
||||
self.assertEqual(simulation.count_bit_errors(d2, y_hat2), 0)
|
||||
self.assertEqual(simulation.count_bit_errors(d3, y_hat3), 4)
|
||||
|
||||
|
||||
# TODO: Rewrite tests for new SNR calculation
|
||||
|
||||
342
sw/utility/simulation.py
Normal file
342
sw/utility/simulation.py
Normal file
@ -0,0 +1,342 @@
|
||||
"""This file contains utility functions relating to tests and simulations of the decoders."""
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import typing
|
||||
from tqdm import tqdm
|
||||
from timeit import default_timer
|
||||
import signal
|
||||
from dataclasses import dataclass
|
||||
import pickle
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
import spdlog as spd
|
||||
|
||||
from utility import noise
|
||||
|
||||
|
||||
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
||||
"""Count the number of wrong bits in a decoded codeword.
|
||||
|
||||
:param d: Originally sent data
|
||||
:param d_hat: Received data
|
||||
:return: Number of bit errors
|
||||
"""
|
||||
return np.sum(d != d_hat)
|
||||
|
||||
|
||||
# def test_decoder(n: int,
|
||||
# k: int,
|
||||
# decoder: typing.Any,
|
||||
# SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
# target_frame_errors: int = 100) \
|
||||
# -> typing.Tuple[np.array, np.array]:
|
||||
# """Calculate the Bit Error Rate (BER) for a given decoder for a number of SNRs.
|
||||
#
|
||||
# This function assumes the all-zeros assumption holds. Progress is printed to stdout.
|
||||
#
|
||||
# :param n: Length of a codeword of the used code
|
||||
# :param k: Length of a dataword of the used code
|
||||
# :param decoder: Instance of the decoder to be tested
|
||||
# :param SNRs: List of SNRs for which the BER should be calculated
|
||||
# :param target_frame_errors: Number of frame errors after which to stop the simulation
|
||||
# :param N_max: Maximum number of iterations to perform for each SNR
|
||||
# :return: Tuple of numpy arrays of the form (SNRs, BERs)
|
||||
# """
|
||||
#
|
||||
# x = np.zeros(n)
|
||||
# x_bpsk = 1 - 2 * x # Map x from [0, 1]^n to [-1, 1]^n
|
||||
#
|
||||
# BERs = []
|
||||
# for SNR in tqdm(SNRs,
|
||||
# desc=f"Calculating BERs for {decoder.__class__.__name__}",
|
||||
# position=1,
|
||||
# leave=False,
|
||||
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
|
||||
#
|
||||
# total_bit_errors = 0
|
||||
# total_bits = 0
|
||||
# total_frame_errors = 0
|
||||
#
|
||||
# pbar = tqdm(total=target_frame_errors,
|
||||
# desc=f"Simulating for SNR = {SNR} dB",
|
||||
# position=2,
|
||||
# leave=False,
|
||||
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
|
||||
#
|
||||
# while total_frame_errors < target_frame_errors:
|
||||
# # Simulate channel
|
||||
# y = noise.add_awgn(x_bpsk, SNR, n, k)
|
||||
#
|
||||
# # Decode received frame
|
||||
# x_hat = decoder.decode(y)
|
||||
#
|
||||
# # Calculate statistics
|
||||
# bit_errors = count_bit_errors(x, x_hat)
|
||||
# total_bits += x.size
|
||||
#
|
||||
# if bit_errors > 0:
|
||||
# total_frame_errors += 1
|
||||
# total_bit_errors += bit_errors
|
||||
# pbar.update(1)
|
||||
#
|
||||
# pbar.close()
|
||||
#
|
||||
# BERs.append(total_bit_errors / total_bits)
|
||||
#
|
||||
# return np.array(SNRs), np.array(BERs)
|
||||
#
|
||||
#
|
||||
# def test_decoders(n: int,
|
||||
# k: int,
|
||||
# decoders: typing.List,
|
||||
# SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
# target_frame_errors: int = 100) \
|
||||
# -> typing.Tuple[np.array, np.array]:
|
||||
# """Calculate the Bit Error Rate (BER) for a number of given decoders for a number of SNRs.
|
||||
#
|
||||
# This function assumes the all-zeros assumption holds. Progress is printed to stdout.
|
||||
#
|
||||
# :param n: Length of a codeword of the used code
|
||||
# :param k: Length of a dataword of the used code
|
||||
# :param decoders: List of decoder objects to be tested
|
||||
# :param SNRs: List of SNRs for which the BER should be calculated
|
||||
# :param target_frame_errors: Number of frame errors after which to stop the simulation
|
||||
# :return: Tuple of the form (SNRs, [BERs_1, BERs_2, ...]) where SNR and BERs_x are numpy arrays
|
||||
# """
|
||||
# result_BERs = []
|
||||
#
|
||||
# start_time = default_timer()
|
||||
#
|
||||
# for decoder in tqdm(decoders,
|
||||
# desc="Calculating the answer to life, the universe and everything",
|
||||
# position=0,
|
||||
# leave=False,
|
||||
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
|
||||
# _, BERs = test_decoder(n, k, decoder, SNRs, target_frame_errors)
|
||||
# result_BERs.append(BERs)
|
||||
#
|
||||
# end_time = default_timer()
|
||||
# print(f"Elapsed time: {end_time - start_time:.2f}s")
|
||||
#
|
||||
# return SNRs, result_BERs
|
||||
#
|
||||
|
||||
@dataclass
|
||||
class SimulationParameters:
|
||||
n: int
|
||||
k: int
|
||||
decoders: typing.Sequence[typing.Any]
|
||||
SNRs: typing.Sequence[float]
|
||||
target_frame_errors: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimulationState:
|
||||
"""Data structure storing the state of the simulation."""
|
||||
num_frame_errors: int = 0
|
||||
num_bit_errors: int = 0
|
||||
num_total_bits: int = 0
|
||||
# simulation_time: float = 0
|
||||
|
||||
current_decoder_index = 0
|
||||
current_SNRs_index: int = 0
|
||||
|
||||
|
||||
# TODO: Make more generic
|
||||
# TODO: Remove save data after successful execution
|
||||
class SimulationManager:
|
||||
def __init__(self, save_dir: str, results_dir: str):
|
||||
self._save_dir = save_dir
|
||||
self._sim_parameters_filepath = f"{self._save_dir}/sim_parameters.pickle"
|
||||
self._sim_state_filepath = f"{self._save_dir}/sim_state.pickle"
|
||||
self._logs_filepath = f"{self._save_dir}/logs.txt"
|
||||
self._results_dir = results_dir
|
||||
|
||||
# TODO: Should the be none or SimulationParameters() and SimulationState() respectively?
|
||||
self._sim_params = None
|
||||
self._sim_state = None
|
||||
|
||||
self._sim_running = False
|
||||
|
||||
Path(self._save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._logger = spd.FileLogger("SimulationManager", self._logs_filepath)
|
||||
self._logger.set_level(spd.LogLevel.DEBUG)
|
||||
|
||||
signal.signal(signal.SIGINT, self._exit_gracefully)
|
||||
signal.signal(signal.SIGTERM, self._exit_gracefully)
|
||||
|
||||
#
|
||||
# Functions relating to the pausing and restarting of simulations
|
||||
#
|
||||
|
||||
def unfinished_simulation_present(self) -> bool:
|
||||
return os.path.isfile(self._sim_parameters_filepath) \
|
||||
and os.path.isfile(self._sim_state_filepath)
|
||||
|
||||
def continue_unfinished(self):
|
||||
assert self.unfinished_simulation_present()
|
||||
|
||||
with open(self._sim_parameters_filepath, "rb") as file:
|
||||
self._sim_params = pickle.load(file)
|
||||
with open(self._sim_state_filepath, "rb") as file:
|
||||
self._sim_state = pickle.load(file)
|
||||
|
||||
self._logger.info("Loaded saved simulation state")
|
||||
|
||||
self.start()
|
||||
|
||||
# TODO: Make sure old state is overwritten
|
||||
def _save_state(self):
|
||||
with open(self._sim_parameters_filepath, "wb") as file:
|
||||
pickle.dump(self._sim_params, file)
|
||||
with open(self._sim_state_filepath, "wb") as file:
|
||||
pickle.dump(self._sim_state, file)
|
||||
|
||||
self._logger.info("Saved simulation state")
|
||||
|
||||
def _exit_gracefully(self, *args):
|
||||
self._logger.debug("Intercepted signal SIGINT/SIGTERM")
|
||||
|
||||
self._sim_running = False
|
||||
|
||||
if (self._sim_params is not None) and (self._sim_state is not None):
|
||||
self._save_state()
|
||||
|
||||
#
|
||||
# Functions responsible for the actual simulation
|
||||
#
|
||||
|
||||
# def test_decoders(self,
|
||||
# n: int,
|
||||
# k: int,
|
||||
# decoders: typing.Sequence[typing.Any],
|
||||
# SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
# target_frame_errors: int = 100):
|
||||
# """Calculate the Bit Error Rate (BER) for a number of given decoders for a number of SNRs.
|
||||
#
|
||||
# This function assumes the all-zeros assumption holds. Progress is printed to stdout.
|
||||
#
|
||||
# :param n: Length of a codeword of the used code
|
||||
# :param k: Length of a dataword of the used code
|
||||
# :param decoders: List of decoder objects to be tested
|
||||
# :param SNRs: List of SNRs for which the BER should be calculated
|
||||
# :param target_frame_errors: Number of frame errors after which to stop the simulation
|
||||
# :return: Tuple of the form (SNRs, [BERs_1, BERs_2, ...]) where SNR and BERs_x are numpy arrays
|
||||
# """
|
||||
# # TODO
|
||||
#
|
||||
# # Save simulation
|
||||
# self._sim_parameters = SimulationMetaData(n, k, decoders, SNRs, target_frame_errors)
|
||||
# self._sim_state = SimulationState()
|
||||
#
|
||||
# self._logger.info("Initialized new simulation state")
|
||||
#
|
||||
# # Simulation
|
||||
#
|
||||
# result_BERs = []
|
||||
#
|
||||
# start_time = default_timer()
|
||||
#
|
||||
# for decoder in tqdm(decoders,
|
||||
# desc="Calculating the answer to life, the universe and everything",
|
||||
# position=0,
|
||||
# leave=False,
|
||||
# bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
|
||||
# _, BERs = self.test_decoder(n, k, decoder, SNRs, target_frame_errors)
|
||||
# result_BERs.append(BERs)
|
||||
#
|
||||
# end_time = default_timer()
|
||||
# print(f"Elapsed time: {end_time - start_time:.2f}s")
|
||||
#
|
||||
# return SNRs, result_BERs
|
||||
|
||||
# def test_decoder(self,
|
||||
# n: int,
|
||||
# k: int,
|
||||
# decoder: typing.Any,
|
||||
# SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
# target_frame_errors: int = 100) \
|
||||
# -> typing.Tuple[np.array, np.array]:
|
||||
def start(self):
|
||||
self._sim_running = True # TODO: Move this somewhere else
|
||||
|
||||
decoder = self._sim_params.decoders[self._sim_state.current_decoder_index]
|
||||
|
||||
x = np.zeros(self._sim_params.n)
|
||||
x_bpsk = 1 - 2 * x # Map x from [0, 1]^n to [-1, 1]^n
|
||||
|
||||
BERs = []
|
||||
for SNR in tqdm(self._sim_params.SNRs[self._sim_state.current_SNRs_index:],
|
||||
desc=f"Calculating BERs for {decoder.__class__.__name__}",
|
||||
position=1,
|
||||
leave=False,
|
||||
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
|
||||
|
||||
pbar = tqdm(total=self._sim_params.target_frame_errors,
|
||||
desc=f"Simulating for SNR = {SNR} dB",
|
||||
position=2,
|
||||
leave=False,
|
||||
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
|
||||
|
||||
pbar.update(self._sim_state.num_frame_errors)
|
||||
|
||||
while self._sim_state.num_frame_errors < self._sim_params.target_frame_errors:
|
||||
if not self._sim_running:
|
||||
return
|
||||
|
||||
# Simulate channel
|
||||
y = noise.add_awgn(x_bpsk, SNR, self._sim_params.n, self._sim_params.k)
|
||||
|
||||
# Decode received frame
|
||||
x_hat = decoder.decode(y)
|
||||
|
||||
# Calculate statistics
|
||||
bit_errors = count_bit_errors(x, x_hat)
|
||||
self._sim_state.num_total_bits += x.size
|
||||
|
||||
if bit_errors > 0:
|
||||
self._sim_state.num_frame_errors += 1
|
||||
self._sim_state.num_bit_errors += bit_errors
|
||||
pbar.update(1)
|
||||
|
||||
# TODO: Load BERs from file as well
|
||||
BERs.append(self._sim_state.num_bit_errors / self._sim_state.num_total_bits)
|
||||
|
||||
pbar.close()
|
||||
self._sim_state.current_SNRs_index += 1
|
||||
self._sim_state.num_frame_errors = 0
|
||||
self._sim_state.num_bit_errors = 0
|
||||
self._sim_state.num_total_bits = 0
|
||||
|
||||
# return np.array(self._sim_params.SNRs), np.array(BERs)
|
||||
|
||||
|
||||
class DecoderTester:
|
||||
"""Class used to test decoders simulating BPSK modulation and an AWGN channel.
|
||||
|
||||
Allows for recovering a stopped simulation if its previous state is known.
|
||||
"""
|
||||
|
||||
def __init__(self, initial_sim_state: SimulationState = SimulationState()):
|
||||
"""Construct a DecoderTester object.
|
||||
|
||||
:param initial_sim_state: State the simulation should start from
|
||||
"""
|
||||
self._state = initial_sim_state
|
||||
|
||||
def get_state(self) -> SimulationState:
|
||||
return self._state
|
||||
|
||||
def configure(self, n: int,
|
||||
k: int,
|
||||
SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
target_frame_errors: int = 100):
|
||||
pass
|
||||
|
||||
def start_test(self, decoders: typing.List):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
@ -1,115 +0,0 @@
|
||||
"""This file contains utility functions relating to tests and simulations of the decoders."""
|
||||
|
||||
import numpy as np
|
||||
import typing
|
||||
from tqdm import tqdm
|
||||
from timeit import default_timer
|
||||
|
||||
from utility import noise
|
||||
|
||||
|
||||
def count_bit_errors(d: np.array, d_hat: np.array) -> int:
|
||||
"""Count the number of wrong bits in a decoded codeword.
|
||||
|
||||
:param d: Originally sent data
|
||||
:param d_hat: Received data
|
||||
:return: Number of bit errors
|
||||
"""
|
||||
return np.sum(d != d_hat)
|
||||
|
||||
|
||||
def test_decoder(n: int,
|
||||
k: int,
|
||||
decoder: typing.Any,
|
||||
SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
target_frame_errors: int = 100) \
|
||||
-> typing.Tuple[np.array, np.array]:
|
||||
"""Calculate the Bit Error Rate (BER) for a given decoder for a number of SNRs.
|
||||
|
||||
This function assumes the all-zeros assumption holds. Progress is printed to stdout.
|
||||
|
||||
:param n: Length of a codeword of the used code
|
||||
:param k: Length of a dataword of the used code
|
||||
:param decoder: Instance of the decoder to be tested
|
||||
:param SNRs: List of SNRs for which the BER should be calculated
|
||||
:param target_frame_errors: Number of frame errors after which to stop the simulation
|
||||
:param N_max: Maximum number of iterations to perform for each SNR
|
||||
:return: Tuple of numpy arrays of the form (SNRs, BERs)
|
||||
"""
|
||||
|
||||
x = np.zeros(n)
|
||||
x_bpsk = 1 - 2 * x # Map x from [0, 1]^n to [-1, 1]^n
|
||||
|
||||
BERs = []
|
||||
for SNR in tqdm(SNRs,
|
||||
desc=f"Calculating BERs for {decoder.__class__.__name__}",
|
||||
position=1,
|
||||
leave=False,
|
||||
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
|
||||
|
||||
total_bit_errors = 0
|
||||
total_bits = 0
|
||||
total_frame_errors = 0
|
||||
|
||||
pbar = tqdm(total=target_frame_errors,
|
||||
desc=f"Simulating for SNR = {SNR} dB",
|
||||
position=2,
|
||||
leave=False,
|
||||
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]")
|
||||
|
||||
while total_frame_errors < target_frame_errors:
|
||||
# Simulate channel
|
||||
y = noise.add_awgn(x_bpsk, SNR, n, k)
|
||||
|
||||
# Decode received frame
|
||||
x_hat = decoder.decode(y)
|
||||
|
||||
# Calculate statistics
|
||||
bit_errors = count_bit_errors(x, x_hat)
|
||||
total_bits += x.size
|
||||
|
||||
if bit_errors > 0:
|
||||
total_frame_errors += 1
|
||||
total_bit_errors += bit_errors
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
BERs.append(total_bit_errors / total_bits)
|
||||
|
||||
return np.array(SNRs), np.array(BERs)
|
||||
|
||||
|
||||
def test_decoders(n: int,
|
||||
k: int,
|
||||
decoders: typing.List,
|
||||
SNRs: typing.Sequence[float] = np.linspace(1, 7, 7),
|
||||
target_frame_errors: int = 100) \
|
||||
-> typing.Tuple[np.array, np.array]:
|
||||
"""Calculate the Bit Error Rate (BER) for a number of given decoders for a number of SNRs.
|
||||
|
||||
This function assumes the all-zeros assumption holds. Progress is printed to stdout.
|
||||
|
||||
:param n: Length of a codeword of the used code
|
||||
:param k: Length of a dataword of the used code
|
||||
:param decoders: List of decoder objects to be tested
|
||||
:param SNRs: List of SNRs for which the BER should be calculated
|
||||
:param target_frame_errors: Number of frame errors after which to stop the simulation
|
||||
:return: Tuple of the form (SNRs, [BERs_1, BERs_2, ...]) where SNR and BERs_x are numpy arrays
|
||||
"""
|
||||
result_BERs = []
|
||||
|
||||
start_time = default_timer()
|
||||
|
||||
for decoder in tqdm(decoders,
|
||||
desc="Calculating the answer to life, the universe and everything",
|
||||
position=0,
|
||||
leave=False,
|
||||
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"):
|
||||
_, BERs = test_decoder(n, k, decoder, SNRs, target_frame_errors)
|
||||
result_BERs.append(BERs)
|
||||
|
||||
end_time = default_timer()
|
||||
print(f"Elapsed time: {end_time - start_time:.2f}s")
|
||||
|
||||
return SNRs, result_BERs
|
||||
Loading…
Reference in New Issue
Block a user