Properly implemented visualization._get_num_rows() and added unit tests
This commit is contained in:
parent
b87129df2a
commit
6105eef4c1
24
sw/test/test_visualization.py
Normal file
24
sw/test/test_visualization.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import unittest
|
||||||
|
from utility import visualization
|
||||||
|
|
||||||
|
|
||||||
|
class NumRowsTestCase(unittest.TestCase):
|
||||||
|
def test_get_num_rows(self):
|
||||||
|
"""Test case for number of row calculation."""
|
||||||
|
|
||||||
|
num_rows1 = visualization._get_num_rows(4, 2)
|
||||||
|
expected_rows1 = 2
|
||||||
|
|
||||||
|
num_rows2 = visualization._get_num_rows(5, 2)
|
||||||
|
expected_rows2 = 3
|
||||||
|
|
||||||
|
num_rows3 = visualization._get_num_rows(4, 4)
|
||||||
|
expected_rows3 = 1
|
||||||
|
|
||||||
|
num_rows4 = visualization._get_num_rows(4, 5)
|
||||||
|
expected_rows4 = 1
|
||||||
|
|
||||||
|
self.assertEqual(num_rows1, expected_rows1)
|
||||||
|
self.assertEqual(num_rows2, expected_rows2)
|
||||||
|
self.assertEqual(num_rows3, expected_rows3)
|
||||||
|
self.assertEqual(num_rows4, expected_rows4)
|
||||||
@ -3,6 +3,7 @@ import matplotlib.pyplot as plt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import typing
|
import typing
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
||||||
@ -13,7 +14,7 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
|||||||
:param num_cols: Number of columns
|
:param num_cols: Number of columns
|
||||||
:return: Number of rows
|
:return: Number of rows
|
||||||
"""
|
"""
|
||||||
return num_graphs // num_cols + 1
|
return math.ceil(num_graphs / num_cols)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Calculate fig size in relation to the number of rows and columns
|
# TODO: Calculate fig size in relation to the number of rows and columns
|
||||||
@ -23,8 +24,8 @@ def _get_num_rows(num_graphs: int, num_cols: int) -> int:
|
|||||||
def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure:
|
def show_BER_curves(data: typing.List[pd.DataFrame], num_cols: int = 3) -> plt.figure:
|
||||||
"""This function creates a matplotlib figure containing a number of BER curves.
|
"""This function creates a matplotlib figure containing a number of BER curves.
|
||||||
|
|
||||||
:param data: List of pandas DataFrames containing the data to be plotted. Each element in the list is plotted in
|
:param data: List of pandas DataFrames containing the data to be plotted. Each dataframe in the list is plotted
|
||||||
a new graph. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis
|
in a new graph. Each dataframe is assumed to contain a column named "SNR" which is used as the x-axis
|
||||||
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
|
:param num_cols: Number of columns in which the graphs should be arranged in the resulting figure
|
||||||
:return: Matplotlib figure
|
:return: Matplotlib figure
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user