Add sliding-window decoding slide

This commit is contained in:
2026-04-15 18:40:29 +02:00
parent e940e7ab9f
commit 4250f2a903
4 changed files with 834 additions and 769 deletions

View File

@@ -0,0 +1,195 @@
import warnings
from typing import Sequence
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as pt
from scipy.sparse import csc_matrix
from quits.decoder import spacetime
from quits.decoder import detector_error_model_to_matrix
from quits.qldpc_code import BbCode
from quits import ErrorModel, CircuitBuildOptions
def build_bb_circuit(N: int, num_rounds: int, p: float):
# fmt: off
if N == 72:
code = BbCode(l=6, m=6, A_x_pows=[3], A_y_pows=[1, 2], B_x_pows=[1, 2], B_y_pows=[3])
elif N == 90:
code = BbCode(l=15, m=3, A_x_pows=[9], A_y_pows=[1, 2], B_x_pows=[2, 7], B_y_pows=[0])
elif N == 108:
code = BbCode(l=9, m=6, A_x_pows=[3], A_y_pows=[1, 2], B_x_pows=[1, 2], B_y_pows=[3])
elif N == 144:
code = BbCode(l=12, m=6, A_x_pows=[3], A_y_pows=[1, 2], B_x_pows=[1, 2], B_y_pows=[3])
elif N == 288:
code = BbCode(l=12, m=12, A_x_pows=[3], A_y_pows=[2, 7], B_x_pows=[1, 2], B_y_pows=[3])
elif N == 360:
code = BbCode(l=30, m=6, A_x_pows=[9], A_y_pows=[1, 2], B_x_pows=[25, 26], B_y_pows=[3])
elif N == 756:
code = BbCode(l=21, m=18, A_x_pows=[3], A_y_pows=[10, 17], B_x_pows=[3, 19], B_y_pows=[5])
else:
assert False, "Unsupported code size"
# fmt: on
circuit = code.build_circuit(
error_model=ErrorModel(p, p, p, p),
num_rounds=num_rounds,
basis="Z",
circuit_build_options=CircuitBuildOptions(),
seed=1,
)
return code, circuit
def compute_num_windows(num_rounds: int, W: int, F: int):
"""
This was extracted from the function `sliding_window_circuit_mem()` of
`quits.decoder`.
"""
if 2 + num_rounds - W >= 0:
# num_cor_rounds = num of windows before the last window
num_cor_rounds = (2 + num_rounds - W) // F
# we can slide one more window if the remaining rounds > W
if (2 + num_rounds - W) % F != 0:
num_cor_rounds += 1
else:
num_cor_rounds = 0
warnings.warn(
"Window size larger than the syndrome extraction rounds: Doing"
" whole history correction"
)
return num_cor_rounds + 1
def get_overlap_info(
col_start_indices: Sequence, W: int, F: int, m: int, win_check_set: Sequence
):
def i_B(k: int):
return col_start_indices[k]
def i_E(k: int):
return i_B(k) + win_check_set[k].shape[1]
def j_B(k: int):
return F * (k) * m
def j_E(k: int):
return (F * k + W) * m
num_windows = len(win_check_set)
overlap_begin_positions = []
for k in range(num_windows - 1):
overlap_begin_positions.append((j_B(k + 1) - j_B(k), i_B(k + 1) - i_B(k)))
overlap_end_positions = []
for k in range(1, num_windows):
overlap_end_positions.append((j_E(k - 1) - j_B(k), i_E(k - 1) - i_B(k)))
return overlap_begin_positions, overlap_end_positions
def reconstruct_window_start_col_indices(win_observable_set: Sequence):
"""
This function effectively just reconstructs the `col_min` values of each
window, from the `spacetime()` function of `quits.decoder`.
"""
num_windows = len(win_observable_set)
col_mins = [0]
for k in range(num_windows - 1):
col_mins.append(col_mins[-1] + win_observable_set[k].shape[1])
return col_mins
num_rounds = 12
N = 72
p = 0.005
W = 5
F = 3
#
# Get detector error matrix and split it into windows
#
code, circuit = build_bb_circuit(N, num_rounds, p)
model = circuit.detector_error_model(decompose_errors=False)
check_matrix, observable_matrix, priors = detector_error_model_to_matrix(model)
num_windows = compute_num_windows(num_rounds, W, F)
win_check_set, win_observable_set, win_priors_set, win_update = spacetime(
circuit, code.hz, W, F, num_windows - 1
)
col_start_indices = reconstruct_window_start_col_indices(win_observable_set)
#
# Paint rectangles
#
custom_colors = [
(162/255, 34/255, 35/255),
(223/255, 155/255, 27/255),
(70/255, 100/255, 170/255),
(163/255, 16/255, 124/255),
]
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.spy(check_matrix.toarray())
colors = [custom_colors[i % len(custom_colors)] for i in range(num_windows)]
m = code.hz.shape[0]
# for win_idx in range(num_windows):
# col_start_idx = col_start_indices[win_idx]
# row_start_idx = win_idx * F * m
#
# ax.add_patch(
# pt.Rectangle(
# (col_start_idx, row_start_idx),
# win_check_set[win_idx].shape[1],
# win_check_set[win_idx].shape[0],
# fc="none",
# ec=colors[win_idx],
# )
# )
# overlap_begin_positions, overlap_end_positions = get_overlap_info(
# col_start_indices, W, F, m, win_check_set
# )
# for k in range(len(win_check_set) - 1):
# ax.add_patch(
# pt.Rectangle(
# (
# overlap_begin_positions[k][1] + col_start_indices[k],
# overlap_begin_positions[k][0] + F * k * m,
# ),
# win_check_set[k].shape[1] - overlap_begin_positions[k][1],
# win_check_set[k].shape[0] - overlap_begin_positions[k][0],
# fc=colors[k],
# ec=colors[k],
# alpha=0.3,
# )
# )
ax.set_xticks([])
ax.set_yticks([])
fig.savefig('72_bb_dem_no_windows.pdf', bbox_inches='tight')
plt.show()

File diff suppressed because it is too large Load Diff

Binary file not shown.

Binary file not shown.