diff --git a/sw/decoders/proximal.py b/sw/decoders/proximal.py index 216f817..d476ee6 100644 --- a/sw/decoders/proximal.py +++ b/sw/decoders/proximal.py @@ -68,25 +68,26 @@ class ProximalDecoder: Detailed in 3.2, p. 3 (Equation (15)). :param x: - :return: + :return: x clipped to [-eta, eta]^n """ return np.clip(x, -self._eta, self._eta) def _check_parity(self, y_hat: np.array) -> bool: """Perform a parity check for a given codeword. - :param y_hat: codeword to be checked + :param y_hat: codeword to be checked (element of [-1, 1]^n) :return: True if the parity check passes, i.e. the codeword is valid. False otherwise """ + y_hat_binary = (y_hat == 1) * 1 # Map the codeword from [-1, 1]^n to [0, 1]^n syndrome = np.dot(self._H, y_hat) % 2 return not np.any(syndrome) def decode(self, y: np.array) -> np.array: """Decode a received signal. The algorithm is detailed in 3.2, p.3. - This function assumes an AWGN channel. + This function assumes a BPSK-like modulated signal ([-1, 1]^n instead of [0, 1]^n) and an AWGN channel. - :param y: Vector of received values + :param y: Vector of received values. (y = x + n, where 'x' is element of [-1, 1]^m and 'n' is noise) :return: Most probably sent symbol """ s = 0 diff --git a/sw/decoders/utility.py b/sw/decoders/utility.py index 60986d3..9ea4fc3 100644 --- a/sw/decoders/utility.py +++ b/sw/decoders/utility.py @@ -45,18 +45,21 @@ def count_bit_errors(d: np.array, d_hat: np.array) -> int: def test_decoder(decoder: typing.Any, c: np.array, SNRs: typing.Sequence[float] = np.linspace(1, 4, 7), - N=10000) \ + target_bit_errors=100, + N_max=10000) \ -> typing.Tuple[np.array, np.array]: """Calculate the Bit Error Rate (BER) for a given decoder for a number of SNRs. This function prints its progress to stdout. :param decoder: Instance of the decoder to be tested - :param c: Codeword whose transmission is to be simulated + :param c: Codeword whose transmission is to be simulated (element of [0, 1]^n) :param SNRs: List of SNRs for which the BER should be calculated - :param N: Number of iterations to perform for each SNR + :param target_bit_errors: Number of bit errors after which to stop the simulation + :param N_max: Maximum number of iterations to perform for each SNR :return: Tuple of numpy arrays of the form (SNRs, BERs) """ + x = c * 2 - 1 # Map the codeword from [0, 1]^n to [-1, 1]^n BERs = [] for SNR in tqdm(SNRs, desc="Calculating Bit-Error-Rates", position=0, @@ -64,19 +67,22 @@ def test_decoder(decoder: typing.Any, total_bit_errors = 0 - for n in tqdm(range(N), desc=f"Simulating for SNR = {SNR} dB", + for n in tqdm(range(N_max), desc=f"Simulating for SNR = {SNR} dB", position=1, leave=False, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"): # TODO: Is this a valid simulation? Can we just add AWGN to the codeword, ignoring and modulation and ( # e.g. matched) filtering? - y = add_awgn(c, SNR) + y = add_awgn(x, SNR) y_hat = decoder.decode(y) total_bit_errors += count_bit_errors(c, y_hat) - total_bits = c.size * N + if total_bit_errors >= target_bit_errors: + break + + total_bits = c.size * N_max BERs.append(total_bit_errors / total_bits) return np.array(SNRs), np.array(BERs) diff --git a/sw/main.py b/sw/main.py index d4ef3fe..0375a23 100644 --- a/sw/main.py +++ b/sw/main.py @@ -27,12 +27,13 @@ def main(): print(f"Simulating with c = {c}") decoder = proximal.ProximalDecoder(H, K=100, gamma=0.01) - SNRs, BERs = utility.test_decoder(decoder, c, SNRs=[1, 3, 20], N=1000) + SNRs, BERs = utility.test_decoder(decoder, c, SNRs=[1, 3, 5, 7], N_max=10000) data = pd.DataFrame({"SNR": SNRs, "BER": BERs}) ax = sns.lineplot(data=data, x="SNR", y="BER") - ax.set_ylim([0, 1]) + ax.set(yscale="log") + #ax.set_ylim([10e-6, 10e0]) plt.show() diff --git a/sw/test/__init__.py b/sw/test/__init__.py index e69de29..2f312a6 100644 --- a/sw/test/__init__.py +++ b/sw/test/__init__.py @@ -0,0 +1 @@ +"""This package contains unit tests.""" diff --git a/sw/test/test_utility.py b/sw/test/test_utility.py index 2c67263..2c47fc8 100644 --- a/sw/test/test_utility.py +++ b/sw/test/test_utility.py @@ -20,6 +20,7 @@ class CountBitErrorsTestCase(unittest.TestCase): self.assertEqual(utility.count_bit_errors(d3, y_hat3), 4) +# TODO: Is this correct? class NoiseAmpFromSNRTestCase(unittest.TestCase): """Test case for noise amplitude calculation.""" def test_get_noise_amp_from_SNR(self):