Skip to content

Commit 0ba2dbc

Browse files
authored
Fix inconsistency between qubits measurements in readout circuits and real measurement circuits (#7801)
There is an inconsistency on qubits measurements in readout circuits and measurement circuits. The measurement circuits always measure the full circuit set of qubits, while the qubits being measured in readout/calibration circuits are based only on the qubits present in the Pauli strings. In this PR, I have updated the circuit generation to only measure the qubits relevant to the Pauli group being processed.
1 parent ba6bf26 commit 0ba2dbc

File tree

2 files changed

+95
-48
lines changed

2 files changed

+95
-48
lines changed

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
import sympy
2727

2828
import cirq.contrib.shuffle_circuits.shuffle_circuits_with_readout_benchmarking as sc_readout
29-
from cirq import circuits, ops, study, work
29+
from cirq import circuits, ops, work
3030
from cirq.experiments.readout_confusion_matrix import TensoredConfusionMatrices
31-
from cirq.study import ResultDict
3231

3332
if TYPE_CHECKING:
33+
import cirq
3434
from cirq.experiments.single_qubit_readout_calibration import (
3535
SingleQubitReadoutCalibrationResult,
3636
)
@@ -276,13 +276,14 @@ def _generate_basis_change_circuits(
276276
insert_strategy: circuits.InsertStrategy,
277277
) -> list[circuits.Circuit]:
278278
"""Generates basis change circuits for each group of Pauli strings."""
279-
pauli_measurement_circuits = list[circuits.Circuit]()
279+
pauli_measurement_circuits: list[circuits.Circuit] = []
280280

281281
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
282-
qid_list = list(sorted(input_circuit.all_qubits()))
283282
basis_change_circuits = []
284283
input_circuit_unfrozen = input_circuit.unfreeze()
285284
for pauli_strings in pauli_string_groups:
285+
# Extract qubits from Pauli strings
286+
qid_list = _extract_readout_qubits(pauli_strings)
286287
basis_change_circuit = circuits.Circuit(
287288
input_circuit_unfrozen,
288289
_pauli_strings_to_basis_change_ops(pauli_strings, qid_list),
@@ -298,31 +299,28 @@ def _generate_basis_change_circuits(
298299
def _generate_basis_change_circuits_with_sweep(
299300
normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]],
300301
insert_strategy: circuits.InsertStrategy,
301-
) -> tuple[list[circuits.Circuit], list[study.Sweepable]]:
302+
) -> tuple[list[circuits.Circuit], list[cirq.Sweepable]]:
302303
"""Generates basis change circuits for each group of Pauli strings with sweep."""
303-
parameterized_circuits = list[circuits.Circuit]()
304-
sweep_params = list[study.Sweepable]()
304+
parameterized_circuits: list[circuits.Circuit] = []
305+
sweep_params: list[cirq.Sweepable] = []
305306
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
306-
qid_list = list(sorted(input_circuit.all_qubits()))
307-
phi_symbols = sympy.symbols(f"phi:{len(qid_list)}")
308-
theta_symbols = sympy.symbols(f"theta:{len(qid_list)}")
309-
310-
# Create phased gates and measurement operator
311-
phased_gates = [
312-
ops.PhasedXPowGate(phase_exponent=(a - 1) / 2, exponent=b)(qubit)
313-
for a, b, qubit in zip(phi_symbols, theta_symbols, qid_list)
314-
]
315-
measurement_op = ops.M(*qid_list, key="result")
316-
317-
parameterized_circuit = circuits.Circuit(
318-
input_circuit.unfreeze(), phased_gates, measurement_op, strategy=insert_strategy
319-
)
320-
sweep_param = []
321307
for pauli_strings in pauli_string_groups:
322-
sweep_param.append(_pauli_strings_to_basis_change_with_sweep(pauli_strings, qid_list))
323-
sweep_params.append(sweep_param)
324-
parameterized_circuits.append(parameterized_circuit)
325-
308+
# Extract qubits from Pauli strings
309+
qid_list = _extract_readout_qubits(pauli_strings)
310+
phi_symbols = sympy.symbols(f"phi:{len(qid_list)}")
311+
theta_symbols = sympy.symbols(f"theta:{len(qid_list)}")
312+
# Create phased gates and measurement operator
313+
phased_gates = [
314+
ops.PhasedXPowGate(phase_exponent=(a - 1) / 2, exponent=b)(qubit)
315+
for a, b, qubit in zip(phi_symbols, theta_symbols, qid_list)
316+
]
317+
measurement_op = ops.M(*qid_list, key="result")
318+
parameterized_circuit = circuits.Circuit(
319+
input_circuit.unfreeze(), phased_gates, measurement_op, strategy=insert_strategy
320+
)
321+
sweep_param = _pauli_strings_to_basis_change_with_sweep(pauli_strings, qid_list)
322+
parameterized_circuits.append(parameterized_circuit)
323+
sweep_params.append(sweep_param)
326324
return parameterized_circuits, sweep_params
327325

328326

@@ -372,9 +370,8 @@ def _build_many_one_qubits_empty_confusion_matrix(qubits_length: int) -> list[np
372370

373371

374372
def _process_pauli_measurement_results(
375-
qubits: Sequence[ops.Qid],
376373
pauli_string_groups: list[list[ops.PauliString]],
377-
circuit_results: Sequence[ResultDict] | Sequence[study.Result],
374+
circuit_results: Sequence[cirq.ResultDict] | Sequence[cirq.Result],
378375
calibration_results: dict[tuple[ops.Qid, ...], SingleQubitReadoutCalibrationResult],
379376
pauli_repetitions: int,
380377
timestamp: float,
@@ -419,7 +416,7 @@ def _process_pauli_measurement_results(
419416

420417
for pauli_str in pauli_strs:
421418
qubits_sorted = sorted(pauli_str.qubits)
422-
qubit_indices = [qubits.index(q) for q in qubits_sorted]
419+
qubit_indices = [pauli_readout_qubits.index(q) for q in qubits_sorted]
423420

424421
if disable_readout_mitigation:
425422
pauli_str_calibration_result = None
@@ -553,8 +550,7 @@ def measure_pauli_strings(
553550

554551
# Build the basis-change circuits for each Pauli string group
555552
pauli_measurement_circuits: list[circuits.Circuit] = []
556-
sweep_params: list[study.Sweepable] = []
557-
circuits_results: Sequence[ResultDict] | Sequence[Sequence[study.Result]] = []
553+
sweep_params: list[cirq.Sweepable] = []
558554
calibration_results: dict[tuple[ops.Qid, ...], SingleQubitReadoutCalibrationResult] = {}
559555

560556
benchmarking_params = sc_readout.ReadoutBenchmarkingParams(
@@ -569,13 +565,15 @@ def measure_pauli_strings(
569565
)
570566

571567
# Run benchmarking using sweep for readout calibration
572-
circuits_results, calibration_results = sc_readout.run_sweep_with_readout_benchmarking(
573-
sampler=sampler,
574-
input_circuits=pauli_measurement_circuits,
575-
sweep_params=sweep_params,
576-
parameters=benchmarking_params,
577-
rng_or_seed=rng_or_seed,
578-
qubits=[list(qubits) for qubits in qubits_list],
568+
sweep_circuits_results, calibration_results = (
569+
sc_readout.run_sweep_with_readout_benchmarking(
570+
sampler=sampler,
571+
input_circuits=pauli_measurement_circuits,
572+
sweep_params=sweep_params,
573+
parameters=benchmarking_params,
574+
rng_or_seed=rng_or_seed,
575+
qubits=[list(qubits) for qubits in qubits_list],
576+
)
579577
)
580578

581579
else:
@@ -597,22 +595,20 @@ def measure_pauli_strings(
597595
# Process the results to calculate expectation values
598596
results: list[CircuitToPauliStringsMeasurementResult] = []
599597
circuit_result_index = 0
600-
for i, (input_circuit, pauli_string_groups) in enumerate(normalized_circuits_to_pauli.items()):
601-
qubits_in_circuit = tuple(sorted(input_circuit.all_qubits()))
602-
598+
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
603599
disable_readout_mitigation = False if num_random_bitstrings != 0 else True
604600

605-
circuits_results_for_group: Sequence[ResultDict] | Sequence[study.Result] = []
601+
circuits_results_for_group: Sequence[cirq.ResultDict] | Sequence[cirq.Result] = []
602+
results_slice = slice(circuit_result_index, circuit_result_index + len(pauli_string_groups))
606603
if use_sweep:
607-
circuits_results_for_group = cast(Sequence[Sequence[study.Result]], circuits_results)[i]
604+
circuits_results_for_group = [r[0] for r in sweep_circuits_results[results_slice]]
605+
608606
else:
609-
circuits_results_for_group = cast(Sequence[ResultDict], circuits_results)[
610-
circuit_result_index : circuit_result_index + len(pauli_string_groups)
611-
]
612-
circuit_result_index += len(pauli_string_groups)
607+
circuits_results_for_group = circuits_results[results_slice]
608+
609+
circuit_result_index += len(pauli_string_groups)
613610

614611
pauli_measurement_results = _process_pauli_measurement_results(
615-
list(qubits_in_circuit),
616612
pauli_string_groups,
617613
circuits_results_for_group,
618614
calibration_results,

cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,3 +803,54 @@ def test_group_paulis_type_mismatch() -> None:
803803
measure_pauli_strings(
804804
circuits_to_pauli, cirq.Simulator(), 300, 300, 300, np.random.default_rng()
805805
)
806+
807+
808+
@pytest.mark.parametrize("use_sweep", [False, True])
809+
def test_sampler_receives_correct_circuits(use_sweep: bool) -> None:
810+
"""Test that the sampler receives circuits with correct measurement qubits."""
811+
812+
from unittest.mock import MagicMock
813+
814+
from cirq.study.result import ResultDict
815+
816+
# Create a circuit with 5 qubits
817+
qubits = cirq.LineQubit.range(5)
818+
819+
circuit = cirq.FrozenCircuit(_create_ghz(5, qubits))
820+
821+
# Create Pauli strings that only use qubits 1, 2, 3 (indices 1,2,3)
822+
pauli_qubits = qubits[1:4] # qubits 1,2,3
823+
pauli_string: cirq.PauliString = cirq.PauliString({q: cirq.X for q in pauli_qubits})
824+
825+
circuits_to_pauli: dict[cirq.FrozenCircuit, list[cirq.PauliString]] = {}
826+
circuits_to_pauli[circuit] = [pauli_string]
827+
828+
# Mock the sampler
829+
mock_sampler = MagicMock()
830+
831+
# Configure the mock sampler to return a valid ResultDict
832+
mock_sampler.run.return_value = ResultDict(
833+
params=cirq.ParamResolver({}), measurements={"result": np.array([[0, 1, 0]])}
834+
)
835+
# Configure the mock sampler to return valid results for run_batch
836+
mock_sampler.run_batch.return_value = [
837+
[ResultDict(params=cirq.ParamResolver({}), measurements={"result": np.array([[0, 1, 0]])})]
838+
]
839+
840+
# Call measure_pauli_strings with the mock sampler
841+
measure_pauli_strings(circuits_to_pauli, mock_sampler, 100, 100, 0, 1234, use_sweep=use_sweep)
842+
843+
# Verify the sampler was called with the correct circuits
844+
batches = [call.args[0] for call in mock_sampler.run_batch.call_args_list]
845+
called_circuits = [circuit for batch in batches for circuit in batch]
846+
847+
for called_circuit in called_circuits:
848+
measured_qubits = set()
849+
for op in called_circuit.all_operations():
850+
if isinstance(op.gate, cirq.MeasurementGate):
851+
measured_qubits.update(op.qubits)
852+
853+
# Ensure only the qubits in the Pauli string are measured
854+
assert measured_qubits == set(
855+
pauli_qubits
856+
), f"Expected measured qubits: {set(pauli_qubits)}, but found: {measured_qubits}"

0 commit comments

Comments
 (0)