Finished initial (non-working) implementation of proximal decoder

This commit is contained in:
2022-11-04 21:07:35 +01:00
parent accc318a77
commit 6444914296
8 changed files with 301 additions and 0 deletions

0
sw/test/__init__.py Normal file
View File

55
sw/test/test_proximal.py Normal file
View File

@@ -0,0 +1,55 @@
import unittest
import numpy as np
from decoders import proximal
class CheckParityTestCase(unittest.TestCase):
"""Test case for the check_parity function."""
def test_check_parity(self):
# Hamming(7,4) code
G = np.array([[1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 1, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[1, 1, 0, 1, 0, 0, 1]])
H = np.array([[1, 0, 1, 0, 1, 0, 1],
[0, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 1]])
decoder = proximal.ProximalDecoder(H)
d1 = np.array([0, 1, 0, 1])
c1 = np.dot(np.transpose(G), d1) % 2
d2 = np.array([0, 0, 0, 0])
c2 = np.dot(np.transpose(G), d2) % 2
d3 = np.array([1, 1, 1, 1])
c3 = np.dot(np.transpose(G), d3) % 2
invalid_codeword = np.array([0, 1, 1, 0, 1, 1, 1])
self.assertEqual(decoder._check_parity(c1), True)
self.assertEqual(decoder._check_parity(c2), True)
self.assertEqual(decoder._check_parity(c3), True)
self.assertEqual(decoder._check_parity(invalid_codeword), False)
class GradientTestCase(unittest.TestCase):
"""Test case for the calculation of the gradient of the code-constraint-polynomial"""
def test_grad_h(self):
H = np.array([[1, 0, 1],
[0, 1, 0]])
x = np.array([2, 3, 4])
decoder = proximal.ProximalDecoder(H)
grad = decoder._grad_h(x)
expected = 4 * (x**2 - 1)*x + 2 / x * np.array([0, 2, 0])
print(f"expected: {expected}")
self.assertEqual(np.array_equal(grad, expected), True)
if __name__ == "__main__":
unittest.main()

40
sw/test/test_utility.py Normal file
View File

@@ -0,0 +1,40 @@
import unittest
import numpy as np
from decoders import utility
class CountBitErrorsTestCase(unittest.TestCase):
"""Test case for bit error counting."""
def test_count_bit_errors(self):
d1 = np.array([0, 0, 0, 0])
y_hat1 = np.array([0, 1, 0, 1])
d2 = np.array([0, 0, 0, 0])
y_hat2 = np.array([0, 0, 0, 0])
d3 = np.array([0, 0, 0, 0])
y_hat3 = np.array([1, 1, 1, 1])
self.assertEqual(utility.count_bit_errors(d1, y_hat1), 2)
self.assertEqual(utility.count_bit_errors(d2, y_hat2), 0)
self.assertEqual(utility.count_bit_errors(d3, y_hat3), 4)
class NoiseAmpFromSNRTestCase(unittest.TestCase):
"""Test case for noise amplitude calculation"""
def test_get_noise_amp_from_SNR(self):
SNR1 = 0
SNR2 = 6
SNR3 = 20
SNR4 = -20
SNR5 = 60
self.assertEqual(utility._get_noise_amp_from_SNR(SNR1, signal_amp=1), 1)
self.assertAlmostEqual(utility._get_noise_amp_from_SNR(SNR2, signal_amp=1), 0.5, places=2)
self.assertEqual(utility._get_noise_amp_from_SNR(SNR3, signal_amp=1), 0.1)
self.assertEqual(utility._get_noise_amp_from_SNR(SNR4, signal_amp=1), 10)
self.assertEqual(utility._get_noise_amp_from_SNR(SNR5, signal_amp=2), 0.002)
if __name__ == '__main__':
unittest.main()