69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
import unittest
|
|
import numpy as np
|
|
from decoders import proximal
|
|
|
|
|
|
class CheckParityTestCase(unittest.TestCase):
    """Test case for the check_parity function."""

    # Hamming(7,4) code: generator matrix (one codeword per row).
    G = np.array([[1, 1, 1, 0, 0, 0, 0],
                  [1, 0, 0, 1, 1, 0, 0],
                  [0, 1, 0, 1, 0, 1, 0],
                  [1, 1, 0, 1, 0, 0, 1]])

    # Parity-check matrix; its columns are the distinct nonzero 3-bit
    # vectors, so any single-bit error yields a nonzero syndrome.
    H = np.array([[1, 0, 1, 0, 1, 0, 1],
                  [0, 1, 1, 0, 0, 1, 1],
                  [0, 0, 0, 1, 1, 1, 1]])

    # Decoding matrix: with this G the message bits sit at codeword
    # positions 2, 4, 5 and 6, which R selects.
    R = np.array([[0, 0, 1, 0, 0, 0, 0],
                  [0, 0, 0, 0, 1, 0, 0],
                  [0, 0, 0, 0, 0, 1, 0],
                  [0, 0, 0, 0, 0, 0, 1]])

    def setUp(self):
        """Build a fresh decoder for the Hamming(7,4) code before each test."""
        self.decoder = proximal.ProximalDecoder(self.H, self.R)

    def _encode(self, d):
        """Return the binary codeword G^T d (mod 2) for the 4-bit message d."""
        return np.dot(np.transpose(self.G), d) % 2

    def test_check_parity(self):
        """Known codewords pass the parity check; a corrupted word fails."""
        c1 = self._encode(np.array([0, 1, 0, 1]))
        c2 = self._encode(np.array([0, 0, 0, 0]))
        c3 = self._encode(np.array([1, 1, 1, 1]))

        invalid_codeword = np.array([0, 1, 1, 0, 1, 1, 1])

        self.assertTrue(self.decoder._check_parity(c1))
        self.assertTrue(self.decoder._check_parity(c2))
        self.assertTrue(self.decoder._check_parity(c3))
        self.assertFalse(self.decoder._check_parity(invalid_codeword))

    def test_check_parity_all_codewords_and_single_bit_errors(self):
        """Every codeword is accepted and every single-bit error is rejected."""
        # np.ndindex(2, 2, 2, 2) enumerates all 16 binary 4-bit messages.
        for message in np.ndindex(2, 2, 2, 2):
            codeword = self._encode(np.array(message))
            with self.subTest(message=message):
                self.assertTrue(self.decoder._check_parity(codeword))

            for position in range(codeword.size):
                corrupted = codeword.copy()
                corrupted[position] ^= 1  # flip exactly one bit
                with self.subTest(message=message, flipped=position):
                    self.assertFalse(self.decoder._check_parity(corrupted))
|
|
|
|
|
|
class GradientTestCase(unittest.TestCase):
    """Test case for the calculation of the gradient of the code-constraint-polynomial."""

    def test_grad_h(self):
        """Test the gradient of the code-constraint polynomial."""
        # Hamming(7,4) code: parity-check matrix.
        H = np.array([[1, 0, 1, 0, 1, 0, 1],
                      [0, 1, 1, 0, 0, 1, 1],
                      [0, 0, 0, 1, 1, 1, 1]])

        # Decoding matrix selecting codeword positions 2, 4, 5 and 6.
        R = np.array([[0, 0, 1, 0, 0, 0, 0],
                      [0, 0, 0, 0, 1, 0, 0],
                      [0, 0, 0, 0, 0, 1, 0],
                      [0, 0, 0, 0, 0, 0, 1]])

        x = np.array([1, 2, -1, -2, 2, 1, -1])  # Some randomly chosen vector

        # Manually calculated result. Each entry equals
        #   4 * (x_k^3 - x_k) + (2 / x_k) * sum_{j in B(k)} (p_j^2 - p_j),
        # the gradient of
        #   h(x) = sum_k (x_k^2 - 1)^2 + sum_j (p_j - 1)^2,
        # where p_j is the product of x over the positions checked by row j
        # of H (here p = [2, 2, 4]) and B(k) is the set of rows checking k.
        # NOTE(review): assumes this is the polynomial _grad_h implements.
        expected_grad_h = np.array([4, 26, -8, -36, 38, 28, -32])

        decoder = proximal.ProximalDecoder(H, R)
        grad_h = decoder._grad_h(x)

        # The gradient involves division by x_k, so compare with a float
        # tolerance; assert_allclose also reports the mismatching entries.
        np.testing.assert_allclose(grad_h, expected_grad_h)
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, discover and run the TestCases in this module.
    unittest.main(module="__main__")
|