import warnings
from typing import Sequence

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as pt
from scipy.sparse import csc_matrix

from quits.decoder import spacetime
from quits.decoder import detector_error_model_to_matrix
from quits.qldpc_code import BbCode
from quits import ErrorModel, CircuitBuildOptions


# RGB palette cycled over the decoding windows when they are drawn.
WINDOW_COLORS = [
    (162 / 255, 34 / 255, 35 / 255),
    (223 / 255, 155 / 255, 27 / 255),
    (70 / 255, 100 / 255, 170 / 255),
    (163 / 255, 16 / 255, 124 / 255),
]


def build_bb_circuit(N: int, num_rounds: int, p: float):
    """Build a bivariate-bicycle (BB) code of size ``N`` and its memory circuit.

    Args:
        N: Code size; one of 72, 90, 108, 144, 288, 360, 756.
        num_rounds: Number of rounds passed to ``code.build_circuit``.
        p: Probability used for all four arguments of ``ErrorModel``.

    Returns:
        ``(code, circuit)``: the ``BbCode`` instance and the circuit returned
        by ``code.build_circuit`` (basis ``"Z"``, default build options,
        ``seed=1``).

    Raises:
        ValueError: If ``N`` is not one of the supported sizes. (An
            ``assert`` is not used here because ``python -O`` strips it.)
    """
    # fmt: off
    if N == 72:
        code = BbCode(l=6, m=6, A_x_pows=[3], A_y_pows=[1, 2], B_x_pows=[1, 2], B_y_pows=[3])
    elif N == 90:
        code = BbCode(l=15, m=3, A_x_pows=[9], A_y_pows=[1, 2], B_x_pows=[2, 7], B_y_pows=[0])
    elif N == 108:
        code = BbCode(l=9, m=6, A_x_pows=[3], A_y_pows=[1, 2], B_x_pows=[1, 2], B_y_pows=[3])
    elif N == 144:
        code = BbCode(l=12, m=6, A_x_pows=[3], A_y_pows=[1, 2], B_x_pows=[1, 2], B_y_pows=[3])
    elif N == 288:
        code = BbCode(l=12, m=12, A_x_pows=[3], A_y_pows=[2, 7], B_x_pows=[1, 2], B_y_pows=[3])
    elif N == 360:
        code = BbCode(l=30, m=6, A_x_pows=[9], A_y_pows=[1, 2], B_x_pows=[25, 26], B_y_pows=[3])
    elif N == 756:
        code = BbCode(l=21, m=18, A_x_pows=[3], A_y_pows=[10, 17], B_x_pows=[3, 19], B_y_pows=[5])
    else:
        raise ValueError(f"Unsupported code size: {N}")
    # fmt: on

    circuit = code.build_circuit(
        error_model=ErrorModel(p, p, p, p),
        num_rounds=num_rounds,
        basis="Z",
        circuit_build_options=CircuitBuildOptions(),
        seed=1,
    )
    return code, circuit


def compute_num_windows(num_rounds: int, W: int, F: int):
    """Return the number of sliding decoding windows for ``num_rounds`` rounds.

    This was extracted from the function `sliding_window_circuit_mem()` of
    `quits.decoder`.

    Args:
        num_rounds: Number of syndrome-extraction rounds in the circuit.
        W: Window size, in rounds.
        F: Number of rounds the window advances between consecutive windows.

    Returns:
        The number of windows (always at least 1). When ``W`` exceeds the
        available rounds, a single whole-history window is used and a
        warning is emitted.
    """
    # NOTE(review): the `2 +` mirrors quits' own computation; presumably it
    # accounts for detector layers beyond `num_rounds` -- confirm upstream.
    remaining = 2 + num_rounds - W
    if remaining >= 0:
        # num_cor_rounds = num of windows before the last window
        num_cor_rounds = remaining // F
        # we can slide one more window if the remaining rounds are not a
        # whole multiple of F
        if remaining % F != 0:
            num_cor_rounds += 1
    else:
        num_cor_rounds = 0
        warnings.warn(
            "Window size larger than the syndrome extraction rounds: Doing"
            " whole history correction",
            stacklevel=2,
        )
    return num_cor_rounds + 1


def get_overlap_info(
    col_start_indices: Sequence, W: int, F: int, m: int, win_check_set: Sequence
):
    """Compute where consecutive windows overlap in the full check matrix.

    Window ``k`` nominally covers rows ``[F*k*m, (F*k + W)*m)`` and columns
    ``[col_start_indices[k], col_start_indices[k] + win_check_set[k].shape[1])``.

    Args:
        col_start_indices: First column of each window in the full matrix.
        W: Window size, in rounds.
        F: Number of rounds the window advances between windows.
        m: Rows per round (presumably detectors per round -- confirm).
        win_check_set: Per-window check matrices (only ``.shape`` is read).

    Returns:
        ``(overlap_begin_positions, overlap_end_positions)``, each a list of
        ``len(win_check_set) - 1`` tuples ``(row_offset, col_offset)``:

        * ``overlap_begin_positions[k]``: start of window ``k + 1`` relative
          to the start of window ``k``.
        * ``overlap_end_positions[k - 1]``: end of window ``k - 1`` relative
          to the start of window ``k``.
    """

    def i_B(k: int):  # first column of window k
        return col_start_indices[k]

    def i_E(k: int):  # one past the last column of window k
        return i_B(k) + win_check_set[k].shape[1]

    def j_B(k: int):  # first row of window k
        return F * k * m

    def j_E(k: int):  # one past the last (nominal) row of window k
        return (F * k + W) * m

    num_windows = len(win_check_set)
    overlap_begin_positions = [
        (j_B(k + 1) - j_B(k), i_B(k + 1) - i_B(k)) for k in range(num_windows - 1)
    ]
    overlap_end_positions = [
        (j_E(k - 1) - j_B(k), i_E(k - 1) - i_B(k)) for k in range(1, num_windows)
    ]
    return overlap_begin_positions, overlap_end_positions


def reconstruct_window_start_col_indices(win_observable_set: Sequence):
    """Return the first column of each window in the full check matrix.

    This function effectively just reconstructs the `col_min` values of each
    window, from the `spacetime()` function of `quits.decoder`, as a running
    sum of the window widths.

    NOTE(review): widths are taken from the observable matrices; this assumes
    each window's observable matrix has the same column count as its check
    matrix -- confirm against `spacetime()`.
    """
    num_windows = len(win_observable_set)
    col_mins = [0]
    for k in range(num_windows - 1):
        col_mins.append(col_mins[-1] + win_observable_set[k].shape[1])
    return col_mins


def paint_windows(ax, win_check_set, col_start_indices, W, F, m, colors=WINDOW_COLORS):
    """Outline each window on ``ax`` and shade the overlap of consecutive ones.

    Rectangles are placed in (column, row) coordinates, matching ``ax.spy``.
    Colors are cycled from ``colors`` by window index.
    """
    for win_idx, win_check in enumerate(win_check_set):
        color = colors[win_idx % len(colors)]
        ax.add_patch(
            pt.Rectangle(
                (col_start_indices[win_idx], win_idx * F * m),
                win_check.shape[1],
                win_check.shape[0],
                fc="none",
                ec=color,
            )
        )

    overlap_begin_positions, _ = get_overlap_info(
        col_start_indices, W, F, m, win_check_set
    )
    for k, (row_offset, col_offset) in enumerate(overlap_begin_positions):
        color = colors[k % len(colors)]
        # The shaded region runs from the start of window k+1 to the end of
        # window k.
        ax.add_patch(
            pt.Rectangle(
                (col_start_indices[k] + col_offset, F * k * m + row_offset),
                win_check_set[k].shape[1] - col_offset,
                win_check_set[k].shape[0] - row_offset,
                fc=color,
                ec=color,
                alpha=0.3,
            )
        )


def main(
    N: int = 72,
    num_rounds: int = 12,
    p: float = 0.005,
    W: int = 5,
    F: int = 3,
    draw_windows: bool = False,
):
    """Plot the sparsity pattern of a BB memory circuit's detector check matrix.

    The figure is saved as ``f"{N}_bb_dem_no_windows.pdf"`` (or
    ``..._windows.pdf`` when ``draw_windows`` is set) and then shown.
    """
    # Get the detector error matrix.
    code, circuit = build_bb_circuit(N, num_rounds, p)
    model = circuit.detector_error_model(decompose_errors=False)
    check_matrix, _, _ = detector_error_model_to_matrix(model)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.spy(check_matrix.toarray())

    if draw_windows:
        # Split the matrix into sliding windows and paint them on top.
        num_windows = compute_num_windows(num_rounds, W, F)
        win_check_set, win_observable_set, _, _ = spacetime(
            circuit, code.hz, W, F, num_windows - 1
        )
        col_start_indices = reconstruct_window_start_col_indices(win_observable_set)
        m = code.hz.shape[0]
        paint_windows(ax, win_check_set, col_start_indices, W, F, m)

    ax.set_xticks([])
    ax.set_yticks([])
    suffix = "windows" if draw_windows else "no_windows"
    fig.savefig(f"{N}_bb_dem_{suffix}.pdf", bbox_inches="tight")
    plt.show()


if __name__ == "__main__":
    main()