from itertools import chain

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib.ticker import FormatStrFormatter


def format_yticks(previous_yticks):
    """Return the tick labels re-rendered in scientific notation.

    Args:
        previous_yticks: Iterable of objects exposing ``get_text()``
            (e.g. the ``Text`` objects returned by
            ``Axes.get_yticklabels()``).

    Returns:
        A list of strings, one per input tick.  Labels that parse as a
        float are formatted as ``"{value:.2e}"``; labels that do not
        (for instance empty strings) are returned unchanged so that a
        single odd label cannot abort the whole plot.
    """
    result = []
    for tick in previous_yticks:
        text = tick.get_text()
        try:
            # Matplotlib may render negative numbers with a Unicode minus
            # (U+2212), which float() rejects; normalise it before parsing.
            value = float(text.replace("\u2212", "-"))
        except ValueError:
            result.append(text)
        else:
            result.append(f"{value:.2e}")
    return result


def main():
    """Draw one heatmap per simulation-result CSV on a shared figure.

    Each CSV is read with its first column as the index, rendered with
    ``seaborn.heatmap`` on its own subplot, given its matching title, and
    its y tick labels are rewritten in scientific notation.  The figure is
    then shown interactively.

    Raises:
        ValueError: If ``titles`` and ``filenames`` differ in length.
    """
    sns.set_theme()

    # Alternative (non-zoomed) data set, kept for reference:
    #
    # titles = [
    #     "$n=7$, $m=4$",
    #     "$n=31$, $m=20$",
    #     "$n=31$, $m=5$",
    #     "$n=96$, $m=48$",
    #     "$n=204$, $m=102$",
    #     "$n=408$, $m=204$",
    # ]
    #
    # filenames = [
    #     "sim_results/2d_dec_fails_w_log_k_lin_bch_7_4alist.csv",
    #     "sim_results/2d_dec_fails_w_log_k_lin_bch_31_11alist.csv",
    #     "sim_results/2d_dec_fails_w_log_k_lin_bch_31_26alist.csv",
    #     "sim_results/2d_dec_fails_w_log_k_lin_963965alist.csv",
    #     "sim_results/2d_dec_fails_w_log_k_lin_2043486alist.csv",
    #     "sim_results/2d_dec_fails_w_log_k_lin_40833844alist.csv",
    # ]

    # NOTE(review): the first title says m=4 for the bch_7_4 file, while the
    # other BCH titles appear to follow m = n - k (31-11=20, 31-26=5).
    # Confirm whether "$m=3$" was intended.
    titles = [
        "$n=7$, $m=4$",
        "$n=31$, $m=20$",
        "$n=31$, $m=5$",
        "$n=96$, $m=48$",
    ]

    filenames = [
        "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_7_4alist.csv",
        "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_31_11alist.csv",
        "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_bch_31_26alist.csv",
        "sim_results/2d_dec_fails_w_log_k_lin_zoomed_in_963965alist.csv",
    ]

    # zip() below would silently drop the surplus entries of the longer list.
    if len(titles) != len(filenames):
        raise ValueError(
            f"got {len(titles)} titles for {len(filenames)} filenames"
        )

    # Size the grid from the number of inputs so that no file is silently
    # skipped when the lists grow (four files still give the 2x2 layout).
    n_cols = 2
    n_rows = max(1, (len(filenames) + n_cols - 1) // n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    fig.suptitle("SNR = 3dB")
    fig.subplots_adjust(left=0.1,
                        bottom=0.1,
                        right=0.9,
                        top=0.9,
                        wspace=0.3,
                        hspace=0.4)

    # Flatten the 2d axes array (row-major order).
    all_axes = list(chain.from_iterable(axes))
    axes = all_axes[:len(filenames)]

    # Hide grid cells that have no data instead of drawing empty axes.
    for unused_axis in all_axes[len(filenames):]:
        unused_axis.set_visible(False)

    for axis, title, filename in zip(axes, titles, filenames):
        df = pd.read_csv(filename, index_col=0)

        sns.heatmap(ax=axis, data=df)

        # Relabel the existing ticks in place; the tick positions are kept.
        axis.set_yticklabels(format_yticks(axis.get_yticklabels()))
        axis.set_title(title)

    plt.show()


if __name__ == "__main__":
    main()