Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
import sympy

import cirq.contrib.shuffle_circuits.shuffle_circuits_with_readout_benchmarking as sc_readout
from cirq import circuits, ops, study, work
from cirq import circuits, ops, work
from cirq.experiments.readout_confusion_matrix import TensoredConfusionMatrices
from cirq.study import ResultDict

if TYPE_CHECKING:
import cirq
from cirq.experiments.single_qubit_readout_calibration import (
SingleQubitReadoutCalibrationResult,
)
Expand Down Expand Up @@ -276,13 +276,14 @@ def _generate_basis_change_circuits(
insert_strategy: circuits.InsertStrategy,
) -> list[circuits.Circuit]:
"""Generates basis change circuits for each group of Pauli strings."""
pauli_measurement_circuits = list[circuits.Circuit]()
pauli_measurement_circuits: list[circuits.Circuit] = []

for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
qid_list = list(sorted(input_circuit.all_qubits()))
basis_change_circuits = []
input_circuit_unfrozen = input_circuit.unfreeze()
for pauli_strings in pauli_string_groups:
# Extract qubits from Pauli strings
qid_list = _extract_readout_qubits(pauli_strings)
basis_change_circuit = circuits.Circuit(
input_circuit_unfrozen,
_pauli_strings_to_basis_change_ops(pauli_strings, qid_list),
Expand All @@ -298,31 +299,28 @@ def _generate_basis_change_circuits(
def _generate_basis_change_circuits_with_sweep(
normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]],
insert_strategy: circuits.InsertStrategy,
) -> tuple[list[circuits.Circuit], list[study.Sweepable]]:
) -> tuple[list[circuits.Circuit], list[cirq.Sweepable]]:
"""Generates basis change circuits for each group of Pauli strings with sweep."""
parameterized_circuits = list[circuits.Circuit]()
sweep_params = list[study.Sweepable]()
parameterized_circuits: list[circuits.Circuit] = []
sweep_params: list[cirq.Sweepable] = []
for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
qid_list = list(sorted(input_circuit.all_qubits()))
phi_symbols = sympy.symbols(f"phi:{len(qid_list)}")
theta_symbols = sympy.symbols(f"theta:{len(qid_list)}")

# Create phased gates and measurement operator
phased_gates = [
ops.PhasedXPowGate(phase_exponent=(a - 1) / 2, exponent=b)(qubit)
for a, b, qubit in zip(phi_symbols, theta_symbols, qid_list)
]
measurement_op = ops.M(*qid_list, key="result")

parameterized_circuit = circuits.Circuit(
input_circuit.unfreeze(), phased_gates, measurement_op, strategy=insert_strategy
)
sweep_param = []
for pauli_strings in pauli_string_groups:
sweep_param.append(_pauli_strings_to_basis_change_with_sweep(pauli_strings, qid_list))
sweep_params.append(sweep_param)
parameterized_circuits.append(parameterized_circuit)

# Extract qubits from Pauli strings
qid_list = _extract_readout_qubits(pauli_strings)
phi_symbols = sympy.symbols(f"phi:{len(qid_list)}")
theta_symbols = sympy.symbols(f"theta:{len(qid_list)}")
# Create phased gates and measurement operator
phased_gates = [
ops.PhasedXPowGate(phase_exponent=(a - 1) / 2, exponent=b)(qubit)
for a, b, qubit in zip(phi_symbols, theta_symbols, qid_list)
]
measurement_op = ops.M(*qid_list, key="result")
parameterized_circuit = circuits.Circuit(
input_circuit.unfreeze(), phased_gates, measurement_op, strategy=insert_strategy
)
sweep_param = _pauli_strings_to_basis_change_with_sweep(pauli_strings, qid_list)
parameterized_circuits.append(parameterized_circuit)
sweep_params.append(sweep_param)
return parameterized_circuits, sweep_params


Expand Down Expand Up @@ -372,9 +370,8 @@ def _build_many_one_qubits_empty_confusion_matrix(qubits_length: int) -> list[np


def _process_pauli_measurement_results(
qubits: Sequence[ops.Qid],
pauli_string_groups: list[list[ops.PauliString]],
circuit_results: Sequence[ResultDict] | Sequence[study.Result],
circuit_results: Sequence[cirq.ResultDict] | Sequence[cirq.Result],
calibration_results: dict[tuple[ops.Qid, ...], SingleQubitReadoutCalibrationResult],
pauli_repetitions: int,
timestamp: float,
Expand Down Expand Up @@ -419,7 +416,7 @@ def _process_pauli_measurement_results(

for pauli_str in pauli_strs:
qubits_sorted = sorted(pauli_str.qubits)
qubit_indices = [qubits.index(q) for q in qubits_sorted]
qubit_indices = [pauli_readout_qubits.index(q) for q in qubits_sorted]

if disable_readout_mitigation:
pauli_str_calibration_result = None
Expand Down Expand Up @@ -553,8 +550,7 @@ def measure_pauli_strings(

# Build the basis-change circuits for each Pauli string group
pauli_measurement_circuits: list[circuits.Circuit] = []
sweep_params: list[study.Sweepable] = []
circuits_results: Sequence[ResultDict] | Sequence[Sequence[study.Result]] = []
sweep_params: list[cirq.Sweepable] = []
calibration_results: dict[tuple[ops.Qid, ...], SingleQubitReadoutCalibrationResult] = {}

benchmarking_params = sc_readout.ReadoutBenchmarkingParams(
Expand All @@ -569,13 +565,15 @@ def measure_pauli_strings(
)

# Run benchmarking using sweep for readout calibration
circuits_results, calibration_results = sc_readout.run_sweep_with_readout_benchmarking(
sampler=sampler,
input_circuits=pauli_measurement_circuits,
sweep_params=sweep_params,
parameters=benchmarking_params,
rng_or_seed=rng_or_seed,
qubits=[list(qubits) for qubits in qubits_list],
sweep_circuits_results, calibration_results = (
sc_readout.run_sweep_with_readout_benchmarking(
sampler=sampler,
input_circuits=pauli_measurement_circuits,
sweep_params=sweep_params,
parameters=benchmarking_params,
rng_or_seed=rng_or_seed,
qubits=[list(qubits) for qubits in qubits_list],
)
)

else:
Expand All @@ -597,22 +595,20 @@ def measure_pauli_strings(
# Process the results to calculate expectation values
results: list[CircuitToPauliStringsMeasurementResult] = []
circuit_result_index = 0
for i, (input_circuit, pauli_string_groups) in enumerate(normalized_circuits_to_pauli.items()):
qubits_in_circuit = tuple(sorted(input_circuit.all_qubits()))

for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items():
disable_readout_mitigation = False if num_random_bitstrings != 0 else True

circuits_results_for_group: Sequence[ResultDict] | Sequence[study.Result] = []
circuits_results_for_group: Sequence[cirq.ResultDict] | Sequence[cirq.Result] = []
results_slice = slice(circuit_result_index, circuit_result_index + len(pauli_string_groups))
if use_sweep:
circuits_results_for_group = cast(Sequence[Sequence[study.Result]], circuits_results)[i]
circuits_results_for_group = [r[0] for r in sweep_circuits_results[results_slice]]

else:
circuits_results_for_group = cast(Sequence[ResultDict], circuits_results)[
circuit_result_index : circuit_result_index + len(pauli_string_groups)
]
circuit_result_index += len(pauli_string_groups)
circuits_results_for_group = circuits_results[results_slice]

circuit_result_index += len(pauli_string_groups)

pauli_measurement_results = _process_pauli_measurement_results(
list(qubits_in_circuit),
pauli_string_groups,
circuits_results_for_group,
calibration_results,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,54 @@ def test_group_paulis_type_mismatch() -> None:
measure_pauli_strings(
circuits_to_pauli, cirq.Simulator(), 300, 300, 300, np.random.default_rng()
)


@pytest.mark.parametrize("use_sweep", [False, True])
def test_sampler_receives_correct_circuits(use_sweep: bool) -> None:
"""Test that the sampler receives circuits with correct measurement qubits."""

from unittest.mock import MagicMock

from cirq.study.result import ResultDict

# Create a circuit with 5 qubits
qubits = cirq.LineQubit.range(5)

circuit = cirq.FrozenCircuit(_create_ghz(5, qubits))

# Create Pauli strings that only use qubits 1, 2, 3 (indices 1,2,3)
pauli_qubits = qubits[1:4] # qubits 1,2,3
pauli_string: cirq.PauliString = cirq.PauliString({q: cirq.X for q in pauli_qubits})

circuits_to_pauli: dict[cirq.FrozenCircuit, list[cirq.PauliString]] = {}
circuits_to_pauli[circuit] = [pauli_string]

# Mock the sampler
mock_sampler = MagicMock()

# Configure the mock sampler to return a valid ResultDict
mock_sampler.run.return_value = ResultDict(
params=cirq.ParamResolver({}), measurements={"result": np.array([[0, 1, 0]])}
)
# Configure the mock sampler to return valid results for run_batch
mock_sampler.run_batch.return_value = [
[ResultDict(params=cirq.ParamResolver({}), measurements={"result": np.array([[0, 1, 0]])})]
]

# Call measure_pauli_strings with the mock sampler
measure_pauli_strings(circuits_to_pauli, mock_sampler, 100, 100, 0, 1234, use_sweep=use_sweep)

# Verify the sampler was called with the correct circuits
batches = [call.args[0] for call in mock_sampler.run_batch.call_args_list]
called_circuits = [circuit for batch in batches for circuit in batch]

for called_circuit in called_circuits:
measured_qubits = set()
for op in called_circuit.all_operations():
if isinstance(op.gate, cirq.MeasurementGate):
measured_qubits.update(op.qubits)

# Ensure only the qubits in the Pauli string are measured
assert measured_qubits == set(
pauli_qubits
), f"Expected measured qubits: {set(pauli_qubits)}, but found: {measured_qubits}"