diff --git a/sw/test/test_visualization.py b/sw/test/test_visualization.py index a32d256..a668b77 100644 --- a/sw/test/test_visualization.py +++ b/sw/test/test_visualization.py @@ -6,16 +6,16 @@ class NumRowsTestCase(unittest.TestCase): def test_get_num_rows(self): """Test case for number of row calculation.""" - num_rows1 = visualization._get_num_rows(4, 2) + num_rows1 = visualization._get_num_rows(num_graphs=4, num_cols=3) expected_rows1 = 2 - num_rows2 = visualization._get_num_rows(5, 2) + num_rows2 = visualization._get_num_rows(num_graphs=5, num_cols=2) expected_rows2 = 3 - num_rows3 = visualization._get_num_rows(4, 4) + num_rows3 = visualization._get_num_rows(num_graphs=4, num_cols=4) expected_rows3 = 1 - num_rows4 = visualization._get_num_rows(4, 5) + num_rows4 = visualization._get_num_rows(num_graphs=4, num_cols=5) expected_rows4 = 1 self.assertEqual(num_rows1, expected_rows1) diff --git a/sw/utility/visualization.py b/sw/utility/visualization.py index 30f2968..c63e9fa 100644 --- a/sw/utility/visualization.py +++ b/sw/utility/visualization.py @@ -21,6 +21,7 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int: # TODO: Set proper line labels # TODO: Set proper axis titles # TODO: Should unnamed columns be dropped by this function or by the caller? +# TODO: Handle number of graphs not nicely fitting into rows and columns def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure: """This function creates a matplotlib figure containing a number of BER curves. @@ -30,9 +31,9 @@ def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.f :return: Matplotlib figure """ num_graphs = len(data) - num_rows = _get_num_rows(num_cols, num_cols) + num_rows = _get_num_rows(num_graphs, num_cols) - fig, axes = plt.subplots(num_rows, num_cols) + fig, axes = plt.subplots(num_rows, num_cols, squeeze=False) fig.suptitle("Bit-Error-Rates of various decoders for different codes") axes = list(chain.from_iterable(axes))[:num_graphs] # Flatten the 2d axes array