# ba-thesis/sw/test/test_proximal.py
#
# 94 lines
# 3.2 KiB
# Python

import unittest
import numpy as np
from decoders import proximal
class CheckParityTestCase(unittest.TestCase):
"""Test case for the check_parity function."""
def test_check_parity(self):
# Hamming(7,4) code
G = np.array([[1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 1, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[1, 1, 0, 1, 0, 0, 1]])
H = np.array([[1, 0, 1, 0, 1, 0, 1],
[0, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 1]])
R = np.array([[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
decoder = proximal.ProximalDecoder(H, R)
d1 = np.array([0, 1, 0, 1])
c1 = np.dot(np.transpose(G), d1) % 2
d2 = np.array([0, 0, 0, 0])
c2 = np.dot(np.transpose(G), d2) % 2
d3 = np.array([1, 1, 1, 1])
c3 = np.dot(np.transpose(G), d3) % 2
invalid_codeword = np.array([0, 1, 1, 0, 1, 1, 1])
self.assertEqual(decoder._check_parity(c1), True)
self.assertEqual(decoder._check_parity(c2), True)
self.assertEqual(decoder._check_parity(c3), True)
self.assertEqual(decoder._check_parity(invalid_codeword), False)
class GradientTestCase(unittest.TestCase):
"""Test case for the calculation of the gradient of the code-constraint-polynomial."""
def test_grad_h(self):
"""Test the gradient of the code-constraint polynomial."""
H = np.array([[1, 0, 0],
[0, 1, 0]])
x = np.array([1, 2, 2])
R = np.array([0])
decoder = proximal.ProximalDecoder(H, R)
grad = decoder._grad_h(x)
expected = 4 * (x**2 - 1)*x + 2 / x * np.array([0, 2, 0])
self.assertEqual(np.array_equal(grad, expected), True)
def test_gen_A_B(self):
"""Test the generation of the A and B sets used for the gradient calculation."""
# Hamming(7,4) code
G = np.array([[1, 1, 1, 0, 0, 0, 0],
[1, 0, 0, 1, 1, 0, 0],
[0, 1, 0, 1, 0, 1, 0],
[1, 1, 0, 1, 0, 0, 1]])
H = np.array([[1, 0, 1, 0, 1, 0, 1],
[0, 1, 1, 0, 0, 1, 1],
[0, 0, 0, 1, 1, 1, 1]])
R = np.array([[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
decoder = proximal.ProximalDecoder(H, R)
expected_A = [np.array([0, 2, 4, 6]),
np.array([1, 2, 5, 6]),
np.array([3, 4, 5, 6])]
expected_B = [np.array([0]),
np.array([1]),
np.array([0, 1]),
np.array([2]),
np.array([0, 2]),
np.array([1, 2]),
np.array([0, 1, 2])]
for A_i, expected_A_i in zip(decoder._A, expected_A):
self.assertEqual(np.array_equal(A_i, expected_A_i), True)
for B_k, expected_B_k in zip(decoder._B, expected_B):
self.assertEqual(np.array_equal(B_k, expected_B_k), True)
def _run_tests():
    """Run this module's test cases through unittest's command-line runner."""
    unittest.main()


if __name__ == "__main__":
    _run_tests()