Fixed bugs in visualization.show_BER_curves

This commit is contained in:
Andreas Tsouchlos 2022-11-11 17:38:15 +01:00
parent 6105eef4c1
commit 0f8de32e3f
2 changed files with 7 additions and 6 deletions

View File

@ -6,16 +6,16 @@ class NumRowsTestCase(unittest.TestCase):
def test_get_num_rows(self): def test_get_num_rows(self):
"""Test case for number of row calculation.""" """Test case for number of row calculation."""
num_rows1 = visualization._get_num_rows(4, 2) num_rows1 = visualization._get_num_rows(num_graphs=4, num_cols=3)
expected_rows1 = 2 expected_rows1 = 2
num_rows2 = visualization._get_num_rows(5, 2) num_rows2 = visualization._get_num_rows(num_graphs=5, num_cols=2)
expected_rows2 = 3 expected_rows2 = 3
num_rows3 = visualization._get_num_rows(4, 4) num_rows3 = visualization._get_num_rows(num_graphs=4, num_cols=4)
expected_rows3 = 1 expected_rows3 = 1
num_rows4 = visualization._get_num_rows(4, 5) num_rows4 = visualization._get_num_rows(num_graphs=4, num_cols=5)
expected_rows4 = 1 expected_rows4 = 1
self.assertEqual(num_rows1, expected_rows1) self.assertEqual(num_rows1, expected_rows1)

View File

@ -21,6 +21,7 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
# TODO: Set proper line labels # TODO: Set proper line labels
# TODO: Set proper axis titles # TODO: Set proper axis titles
# TODO: Should unnamed columns be dropped by this function or by the caller? # TODO: Should unnamed columns be dropped by this function or by the caller?
# TODO: Handle number of graphs not nicely fitting into rows and columns
def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure: def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure:
"""This function creates a matplotlib figure containing a number of BER curves. """This function creates a matplotlib figure containing a number of BER curves.
@ -30,9 +31,9 @@ def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.f
:return: Matplotlib figure :return: Matplotlib figure
""" """
num_graphs = len(data) num_graphs = len(data)
num_rows = _get_num_rows(num_cols, num_cols) num_rows = _get_num_rows(num_graphs, num_cols)
fig, axes = plt.subplots(num_rows, num_cols) fig, axes = plt.subplots(num_rows, num_cols, squeeze=False)
fig.suptitle("Bit-Error-Rates of various decoders for different codes") fig.suptitle("Bit-Error-Rates of various decoders for different codes")
axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array