diff --git a/cirq-core/cirq/contrib/paulistring/__init__.py b/cirq-core/cirq/contrib/paulistring/__init__.py index cc9248c149a..782f5cb8ed2 100644 --- a/cirq-core/cirq/contrib/paulistring/__init__.py +++ b/cirq-core/cirq/contrib/paulistring/__init__.py @@ -45,4 +45,5 @@ from cirq.contrib.paulistring.pauli_string_measurement_with_readout_mitigation import ( measure_pauli_strings as measure_pauli_strings, + CircuitToPauliStringsParameters as CircuitToPauliStringsParameters, ) diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py index 6c7fc1e6dc5..85060f28dda 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation.py @@ -35,6 +35,40 @@ ) +@attrs.frozen +class CircuitToPauliStringsParameters: + """Parameters for measuring Pauli strings on a circuit. + + Attributes: + circuit: The circuit to measure. + pauli_strings: + - A list of QWC groups (list[list[ops.PauliString]]). Each QWC group + is a list of PauliStrings that are mutually Qubit-Wise Commuting. + Pauli strings within the same group will be calculated using the + same measurement results. + postselection_symmetries: A tuple mapping Pauli strings or Pauli sums to + expected values for postselection symmetries. + Measured bitstrings which do not have the indicated + values of the symmetry operators are postselected out. + """ + + circuit: circuits.FrozenCircuit + pauli_strings: list[list[ops.PauliString]] + postselection_symmetries: list[tuple[ops.PauliString | ops.PauliSum, int]] + + +@attrs.frozen +class PostFilteringSymmetryCalibrationResult: + """Result of post-selection symmetry calibration. + Attributes: + raw_bitstrings: The raw bitstrings obtained from the measurement. + filtered_bitstrings: The bitstrings after applying post-selection symmetries. + """ + + raw_bitstrings: np.ndarray + filtered_bitstrings: np.ndarray + + @attrs.frozen class PauliStringMeasurementResult: """Result of measuring a Pauli string. @@ -45,7 +79,11 @@ class PauliStringMeasurementResult: mitigated_stddev: The standard deviation of the error-mitigated expectation value. unmitigated_expectation: The unmitigated expectation value of the Pauli string. unmitigated_stddev: The standard deviation of the unmitigated expectation value. - calibration_result: The calibration result for single-qubit readout errors. + calibration_result: The calibration result for readout errors. It can be either + a SingleQubitReadoutCalibrationResult (in the case of mitigating with confusion + matrices) or a PostFilteringSymmetryCalibrationResult (in the case of mitigating + with post-selection symmetries). + """ pauli_string: ops.PauliString @@ -53,7 +91,9 @@ class PauliStringMeasurementResult: mitigated_stddev: float unmitigated_expectation: float unmitigated_stddev: float - calibration_result: SingleQubitReadoutCalibrationResult | None = None + calibration_result: ( + SingleQubitReadoutCalibrationResult | PostFilteringSymmetryCalibrationResult | None + ) = None @attrs.frozen @@ -91,6 +131,40 @@ def _are_two_pauli_strings_qubit_wise_commuting( return True +def _are_pauli_sum_and_pauli_string_qubit_wise_commuting( + pauli_sum: ops.PauliSum, + pauli_str: ops.PauliString, + all_qubits: list[ops.Qid] | frozenset[ops.Qid], +) -> bool: + """Checks if a Pauli sum and a Pauli string are Qubit-Wise Commuting.""" + for pauli_sum_term in pauli_sum: + for qubit in all_qubits: + op1 = pauli_sum_term.get(qubit, default=ops.I) + op2 = pauli_str.get(qubit, default=ops.I) + + if not _commute_or_identity(op1, op2): + return False + return True + + +def _are_symmetry_and_pauli_string_qubit_wise_commuting( + symmetry: ops.PauliString | ops.PauliSum, + pauli_str: ops.PauliString, + all_qubits: list[ops.Qid] | frozenset[ops.Qid], +) -> bool: + """Checks if a symmetry (Pauli string or Pauli sum) and a Pauli string + are Qubit-Wise Commuting. This is necessary because the code's + post-selection method relies on measuring both the symmetry and the + Pauli string at the same time, using a single experimental shot. + """ + if isinstance(symmetry, ops.PauliSum): + return _are_pauli_sum_and_pauli_string_qubit_wise_commuting(symmetry, pauli_str, all_qubits) + elif isinstance(symmetry, ops.PauliString): + return _are_two_pauli_strings_qubit_wise_commuting(symmetry, pauli_str, all_qubits) + else: + return False + + def _validate_group_paulis_qwc( pauli_strs: list[ops.PauliString], all_qubits: list[ops.Qid] | frozenset[ops.Qid] ): @@ -132,10 +206,79 @@ def _validate_single_pauli_string(pauli_str: ops.PauliString): ) +def _validate_circuit_to_pauli_strings_parameters( + circuits_to_pauli: list[CircuitToPauliStringsParameters], +): + """Validates the input parameters for measuring Pauli strings. + + Args: + circuits_to_pauli: A list of CircuitToPauliStringsParameters objects. + + Raises: + ValueError: If any of the input parameters are invalid. + TypeError: If the types of the input parameters are incorrect. + """ + for i, params in enumerate(circuits_to_pauli): + # 1. Validate Circuit + if not params.circuit: + raise ValueError(f"Item {i}: Circuit must not be empty.") + if not isinstance(params.circuit, circuits.FrozenCircuit): + raise TypeError( + f"Item {i}: Expected circuit to be FrozenCircuit, got {type(params.circuit)}." + ) + + # 2. Validate Pauli strings + for j, pauli_group in enumerate(params.pauli_strings): + if not pauli_group: + raise ValueError( + f"Item {i}, group {j}: Empty group of Pauli strings is not allowed." + ) + if not _validate_group_paulis_qwc(pauli_group, params.circuit.all_qubits()): + raise ValueError( + f"Item {i}, group {j}: Pauli group {pauli_group} is not " + "Qubit-Wise Commuting." + ) + for pauli_str in pauli_group: + _validate_single_pauli_string(pauli_str) + + # 3. Validate postselection symmetries + for sym, _ in params.postselection_symmetries: + if isinstance(sym, ops.PauliSum): + terms = list(sym) + if not _validate_group_paulis_qwc(terms, params.circuit.all_qubits()): + raise ValueError( + f"Pauli sum {sym} for circuit {params.circuit} is invalid: " + "Terms are not Qubit-Wise Commuting." + ) + for term in terms: + _validate_single_pauli_string(term) + elif isinstance(sym, ops.PauliString): + _validate_single_pauli_string(sym) + else: + raise TypeError( + f"Postselection symmetry keys must be cirq.PauliString or cirq.PauliSum, " + f"got {type(sym)}." + ) + + # Check if input symmetries are commuting with all Pauli strings in the circuit + qubits_in_circuit = list(sorted(params.circuit.all_qubits())) + + if not all( + _are_symmetry_and_pauli_string_qubit_wise_commuting(sym, pauli_str, qubits_in_circuit) + for pauli_strs in params.pauli_strings + for pauli_str in pauli_strs + for sym, _ in params.postselection_symmetries + ): + raise ValueError( + f"Postselection symmetries of {params.circuit} are not commuting with all Pauli" + ) + + def _validate_input( circuits_to_pauli: ( dict[circuits.FrozenCircuit, list[ops.PauliString]] | dict[circuits.FrozenCircuit, list[list[ops.PauliString]]] + | list[CircuitToPauliStringsParameters] ), pauli_repetitions: int, readout_repetitions: int, @@ -143,45 +286,11 @@ def _validate_input( rng_or_seed: np.random.Generator | int, ): if not circuits_to_pauli: - raise ValueError("Input circuits must not be empty.") + raise ValueError("Input circuits_to_pauli parameter must not be empty.") - for circuit in circuits_to_pauli.keys(): - if not isinstance(circuit, circuits.FrozenCircuit): - raise TypeError("All keys in 'circuits_to_pauli' must be FrozenCircuit instances.") + normalized_circuits_to_pauli = _validate_and_normalize_unformatted_input(circuits_to_pauli) - first_value: list[ops.PauliString] | list[list[ops.PauliString]] = next( - iter(circuits_to_pauli.values()) # type: ignore - ) - for circuit, pauli_strs_list in circuits_to_pauli.items(): - if isinstance(pauli_strs_list, Sequence) and isinstance(first_value[0], Sequence): - for pauli_strs in pauli_strs_list: - if not pauli_strs: - raise ValueError("Empty group of Pauli strings is not allowed") - if not ( - isinstance(pauli_strs, Sequence) and isinstance(pauli_strs[0], ops.PauliString) - ): - raise TypeError( - f"Inconsistent type in list for circuit {circuit}. " - f"Expected all elements to be sequences of ops.PauliString, " - f"but found {type(pauli_strs)}." - ) - if not _validate_group_paulis_qwc(pauli_strs, circuit.all_qubits()): - raise ValueError( - f"Pauli group containing {pauli_strs} is invalid: " - f"The group of Pauli strings are not " - f"Qubit-Wise Commuting with each other." - ) - for pauli_str in pauli_strs: - _validate_single_pauli_string(pauli_str) - elif isinstance(pauli_strs_list, Sequence) and isinstance(first_value[0], ops.PauliString): - for pauli_str in pauli_strs_list: # type: ignore - _validate_single_pauli_string(pauli_str) - else: - raise TypeError( - f"Expected all elements to be either a sequence of PauliStrings" - f" or sequences of ops.PauliStrings. " - f"Got {type(pauli_strs_list)} instead." - ) + _validate_circuit_to_pauli_strings_parameters(normalized_circuits_to_pauli) # Check rng is a numpy random generator if not isinstance(rng_or_seed, np.random.Generator) and not isinstance(rng_or_seed, int): @@ -199,25 +308,59 @@ def _validate_input( if readout_repetitions <= 0: raise ValueError("Must provide positive readout_repetitions for readout calibration.") + return normalized_circuits_to_pauli -def _normalize_input_paulis( - circuits_to_pauli: ( + +def _validate_and_normalize_unformatted_input( + circuits_input: ( dict[circuits.FrozenCircuit, list[ops.PauliString]] | dict[circuits.FrozenCircuit, list[list[ops.PauliString]]] + | list[CircuitToPauliStringsParameters] ), -) -> dict[circuits.FrozenCircuit, list[list[ops.PauliString]]]: - first_value = next(iter(circuits_to_pauli.values())) - if ( - first_value - and isinstance(first_value, list) - and isinstance(first_value[0], ops.PauliString) - ): - input_dict = cast(dict[circuits.FrozenCircuit, list[ops.PauliString]], circuits_to_pauli) - normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]] = {} - for circuit, paulis in input_dict.items(): - normalized_circuits_to_pauli[circuit] = [[ps] for ps in paulis] - return normalized_circuits_to_pauli - return cast(dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], circuits_to_pauli) +) -> list[CircuitToPauliStringsParameters]: + """Converts any valid input format into a standardized list of parameters + where pauli_strings is always list[list[PauliString]].""" + + param_list: list[CircuitToPauliStringsParameters] = [] + + # 1. Standardize to list[CircuitToPauliStringsParameters] + if isinstance(circuits_input, dict): + for circuit, paulis in circuits_input.items(): + # Normalize flat lists to nested lists + normalized_paulis = paulis + if paulis and isinstance(paulis, list) and isinstance(paulis[0], ops.PauliString): + # Convert [PS, PS] -> [[PS], [PS]] + normalized_paulis = [[cast(ops.PauliString, ps)] for ps in paulis] + + param_list.append( + CircuitToPauliStringsParameters( + circuit=circuit, + pauli_strings=cast(list[list[ops.PauliString]], normalized_paulis), + postselection_symmetries=[], + ) + ) + elif isinstance(circuits_input, list): + param_list = circuits_input + else: + raise TypeError("Input must be a dict or a list of CircuitToPauliStringsParameters.") + + for params in param_list: + if not ( + params.pauli_strings + and isinstance(params.pauli_strings, list) + and all(isinstance(params.pauli_strings, list) for _ in params.pauli_strings) + and all( + isinstance(ps, ops.PauliString) + for ps_list in params.pauli_strings + for ps in ps_list + ) + ): + raise TypeError( + "Expected all elements to be list[list[ops.PauliString]], " + f"but got {type(params.pauli_strings)}." + ) + + return param_list def _extract_readout_qubits(pauli_strings: list[ops.PauliString]) -> list[ops.Qid]: @@ -271,13 +414,16 @@ def _pauli_strings_to_basis_change_with_sweep( def _generate_basis_change_circuits( - normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], + normalized_circuits_to_pauli: list[CircuitToPauliStringsParameters], insert_strategy: circuits.InsertStrategy, ) -> list[circuits.Circuit]: """Generates basis change circuits for each group of Pauli strings.""" pauli_measurement_circuits = list[circuits.Circuit]() - for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items(): + for params in normalized_circuits_to_pauli: + input_circuit = params.circuit + pauli_string_groups = params.pauli_strings + qid_list = list(sorted(input_circuit.all_qubits())) basis_change_circuits = [] input_circuit_unfrozen = input_circuit.unfreeze() @@ -295,13 +441,16 @@ def _generate_basis_change_circuits( def _generate_basis_change_circuits_with_sweep( - normalized_circuits_to_pauli: dict[circuits.FrozenCircuit, list[list[ops.PauliString]]], + normalized_circuits_to_pauli: list[CircuitToPauliStringsParameters], insert_strategy: circuits.InsertStrategy, ) -> tuple[list[circuits.Circuit], list[study.Sweepable]]: """Generates basis change circuits for each group of Pauli strings with sweep.""" parameterized_circuits = list[circuits.Circuit]() sweep_params = list[study.Sweepable]() - for input_circuit, pauli_string_groups in normalized_circuits_to_pauli.items(): + for params in normalized_circuits_to_pauli: + input_circuit = params.circuit + pauli_string_groups = params.pauli_strings + qid_list = list(sorted(input_circuit.all_qubits())) phi_symbols = sympy.symbols(f"phi:{len(qid_list)}") theta_symbols = sympy.symbols(f"theta:{len(qid_list)}") @@ -482,6 +631,7 @@ def measure_pauli_strings( circuits_to_pauli: ( dict[circuits.FrozenCircuit, list[ops.PauliString]] | dict[circuits.FrozenCircuit, list[list[ops.PauliString]]] + | list[CircuitToPauliStringsParameters] ), sampler: work.Sampler, pauli_repetitions: int, @@ -494,7 +644,6 @@ def measure_pauli_strings( """Measures expectation values of Pauli strings on given circuits with/without readout error mitigation. - This function takes a dictionary mapping circuits to lists of QWC Pauli string groups. For each circuit and its associated list of QWC pauli string group, it: 1. Constructs circuits to measure the Pauli string expectation value by adding basis change moments and measurement operations. @@ -512,6 +661,10 @@ def measure_pauli_strings( same measurement results. - A list of PauliStrings (list[ops.PauliString]). In this case, each PauliString is treated as its own measurement group. + - Or a list of CircuitToPauliStringsParameters objects. Each object contains + a circuit and its associated Pauli strings to measure. It could also contain + a dictionary mapping Pauli strings or Pauli sums to expected eigen value + for postselection symmetries. sampler: The sampler to use. pauli_repetitions: The number of repetitions for each circuit when measuring Pauli strings. @@ -532,7 +685,7 @@ def measure_pauli_strings( - The calibration result for single-qubit readout errors. """ - _validate_input( + normalized_circuits_to_pauli = _validate_input( circuits_to_pauli, pauli_repetitions, readout_repetitions, @@ -540,13 +693,11 @@ def measure_pauli_strings( rng_or_seed, ) - normalized_circuits_to_pauli = _normalize_input_paulis(circuits_to_pauli) - # Extract unique qubit tuples from input pauli strings unique_qubit_tuples = set() - for pauli_string_groups in normalized_circuits_to_pauli.values(): - for pauli_strings in pauli_string_groups: - unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_strings))) + for circuit_to_pauli in normalized_circuits_to_pauli: + for pauli_string_groups in circuit_to_pauli.pauli_strings: + unique_qubit_tuples.add(tuple(_extract_readout_qubits(pauli_string_groups))) # qubits_list is a list of qubit tuples qubits_list = sorted(unique_qubit_tuples) @@ -596,7 +747,10 @@ def measure_pauli_strings( # Process the results to calculate expectation values results: list[CircuitToPauliStringsMeasurementResult] = [] circuit_result_index = 0 - for i, (input_circuit, pauli_string_groups) in enumerate(normalized_circuits_to_pauli.items()): + for i, circuit_to_pauli in enumerate(normalized_circuits_to_pauli): + input_circuit = circuit_to_pauli.circuit + pauli_string_groups = circuit_to_pauli.pauli_strings + qubits_in_circuit = tuple(sorted(input_circuit.all_qubits())) disable_readout_mitigation = False if num_random_bitstrings != 0 else True diff --git a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py index da367c4cfc5..e28630b34f4 100644 --- a/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py +++ b/cirq-core/cirq/contrib/paulistring/pauli_string_measurement_with_readout_mitigation_test.py @@ -22,7 +22,7 @@ import pytest import cirq -from cirq.contrib.paulistring import measure_pauli_strings +from cirq.contrib.paulistring import CircuitToPauliStringsParameters, measure_pauli_strings from cirq.experiments import SingleQubitReadoutCalibrationResult from cirq.experiments.single_qubit_readout_calibration_test import NoisySingleQubitReadoutSampler @@ -326,8 +326,16 @@ def test_many_circuits_input_measurement_with_noise(use_sweep: bool) -> None: circuit_2 = cirq.FrozenCircuit(_create_ghz(5, qubits_2)) circuit_3 = cirq.FrozenCircuit(_create_ghz(8, qubits_3)) - circuits_to_pauli: dict[cirq.FrozenCircuit, list[cirq.PauliString]] = {} - circuits_to_pauli[circuit_1] = [_generate_random_pauli_string(qubits_1) for _ in range(3)] + circuits_to_pauli: dict[ + cirq.FrozenCircuit, list[cirq.PauliString] | list[list[cirq.PauliString]] + ] = {} + # This is to test mixed types could be handled. + circuits_to_pauli[circuit_1] = [ + _generate_qwc_paulis( + _generate_random_pauli_string(qubits_1, enable_coeff=True, allow_pauli_i=False), 6 + ) + for _ in range(3) + ] circuits_to_pauli[circuit_2] = [_generate_random_pauli_string(qubits_2) for _ in range(3)] circuits_to_pauli[circuit_3] = [_generate_random_pauli_string(qubits_3) for _ in range(3)] @@ -335,7 +343,13 @@ def test_many_circuits_input_measurement_with_noise(use_sweep: bool) -> None: simulator = cirq.Simulator() circuits_with_pauli_expectations = measure_pauli_strings( - circuits_to_pauli, sampler, 300, 300, 300, np.random.default_rng(), use_sweep + circuits_to_pauli, # type: ignore + sampler, + 300, + 300, + 300, + np.random.default_rng(), + use_sweep, ) for circuit_with_pauli_expectations in circuits_with_pauli_expectations: @@ -570,62 +584,60 @@ def test_coefficient_not_real_number() -> None: def test_empty_input_circuits_to_pauli_mapping() -> None: """Test that the input circuits are empty.""" - with pytest.raises(ValueError, match="Input circuits must not be empty."): - measure_pauli_strings( - [], cirq.Simulator(), 300, 300, 300, np.random.default_rng() # type: ignore[arg-type] - ) + with pytest.raises(ValueError, match="Input circuits_to_pauli parameter must not be empty"): + measure_pauli_strings([], cirq.Simulator(), 300, 300, 300, np.random.default_rng()) -def test_invalid_input_circuit_type() -> None: - """Test that the input circuit type is not frozen circuit""" - qubits = cirq.LineQubit.range(5) +def test_invalid_input_container_type() -> None: + """Test that passing an invalid container type raises TypeError.""" + qubits = cirq.LineQubit.range(2) + circuit = cirq.FrozenCircuit(_create_ghz(2, qubits)) - qubits_to_pauli: dict[tuple, list[cirq.PauliString]] = {} - qubits_to_pauli[tuple(qubits)] = [cirq.PauliString({q: cirq.X for q in qubits})] - with pytest.raises( - TypeError, match="All keys in 'circuits_to_pauli' must be FrozenCircuit instances." - ): + invalid_input = {circuit} + + with pytest.raises(TypeError, match="Input must be a dict or a list"): measure_pauli_strings( - qubits_to_pauli, # type: ignore[arg-type] - cirq.Simulator(), - 300, - 300, - 300, - np.random.default_rng(), + invalid_input, cirq.Simulator(), 100, 100, 100, np.random.default_rng() # type: ignore ) -def test_invalid_input_pauli_string_type() -> None: - """Test input circuit is not mapping to a paulistring""" - qubits_1 = cirq.LineQubit.range(5) - qubits_2 = [ - cirq.GridQubit(0, 1), - cirq.GridQubit(1, 1), - cirq.GridQubit(1, 0), - cirq.GridQubit(1, 2), - cirq.GridQubit(2, 1), - ] - - circuit_1 = cirq.FrozenCircuit(_create_ghz(5, qubits_1)) - circuit_2 = cirq.FrozenCircuit(_create_ghz(5, qubits_2)) +def test_circuit_parameters_validation_errors() -> None: + """Test validation errors specific to CircuitToPauliStringsParameters attributes.""" + q0 = cirq.LineQubit(0) + valid_circuit = cirq.FrozenCircuit(cirq.Circuit(cirq.X(q0))) + valid_pauli: list[list[cirq.PauliString]] = [[cirq.PauliString(cirq.Z(q0))]] - circuits_to_pauli: dict[cirq.FrozenCircuit, cirq.FrozenCircuit] = {} - circuits_to_pauli[circuit_1] = [_generate_random_pauli_string(qubits_1)] # type: ignore - circuits_to_pauli[circuit_2] = [circuit_1, circuit_2] # type: ignore + sampler = cirq.Simulator() + rng = np.random.default_rng() + # Test empty circuit + params_empty_circuit = CircuitToPauliStringsParameters( + circuit=cirq.FrozenCircuit(), # Empty + pauli_strings=valid_pauli, + postselection_symmetries=[], + ) + with pytest.raises(ValueError, match="Circuit must not be empty"): + measure_pauli_strings([params_empty_circuit], sampler, 10, 10, 10, rng) + + # Test Invalid Type for Circuit + params_invalid_circuit_type = CircuitToPauliStringsParameters( + circuit="NotACircuit", # type: ignore + pauli_strings=valid_pauli, + postselection_symmetries=[], + ) + with pytest.raises(TypeError, match="Expected circuit to be FrozenCircuit"): + measure_pauli_strings([params_invalid_circuit_type], sampler, 10, 10, 10, rng) + + # Test Invalid Type in Pauli Strings + params_invalid_type = CircuitToPauliStringsParameters( + circuit=valid_circuit, + pauli_strings=[["NotAPauliString"]], # type: ignore + postselection_symmetries=[], + ) with pytest.raises( - TypeError, - match="All elements in the Pauli string lists must be cirq.PauliString " - "instances, got .", + TypeError, match=r"Expected all elements to be list\[list\[ops.PauliString\]\]" ): - measure_pauli_strings( - circuits_to_pauli, # type: ignore[arg-type] - cirq.Simulator(), - 300, - 300, - 300, - np.random.default_rng(), - ) + measure_pauli_strings([params_invalid_type], sampler, 10, 10, 10, rng) def test_all_pauli_strings_are_pauli_i() -> None: @@ -718,24 +730,6 @@ def test_rng_type_mismatch() -> None: ) -def test_pauli_type_mismatch() -> None: - """Test that the input paulis are not a sequence of PauliStrings.""" - qubits = cirq.LineQubit.range(5) - - circuit = cirq.FrozenCircuit(_create_ghz(5, qubits)) - - circuits_to_pauli: dict[cirq.FrozenCircuit, int] = {} - circuits_to_pauli[circuit] = 1 - with pytest.raises( - TypeError, - match="Expected all elements to be either a sequence of PauliStrings or sequences of" - " ops.PauliStrings. Got instead.", - ): - measure_pauli_strings( - circuits_to_pauli, cirq.Simulator(), 300, 300, 300, 1234 # type: ignore[arg-type] - ) - - def test_group_paulis_are_not_qwc() -> None: """Test that the group paulis are not qwc.""" qubits = cirq.LineQubit.range(5) @@ -747,9 +741,7 @@ def test_group_paulis_are_not_qwc() -> None: circuits_to_pauli: dict[cirq.FrozenCircuit, list[cirq.PauliString]] = {} circuits_to_pauli[circuit] = [[pauli_str1, pauli_str2]] # type: ignore - with pytest.raises( - ValueError, match="The group of Pauli strings are not Qubit-Wise Commuting with each other." - ): + with pytest.raises(ValueError, match="is not Qubit-Wise Commuting."): measure_pauli_strings( circuits_to_pauli, cirq.Simulator(), 300, 300, 300, np.random.default_rng() ) @@ -769,37 +761,60 @@ def test_empty_group_paulis_not_allowed() -> None: ) -def test_group_paulis_type_mismatch() -> None: - """Test that the group paulis type is not correct""" - qubits_1 = cirq.LineQubit.range(3) - qubits_2 = [ - cirq.GridQubit(0, 1), - cirq.GridQubit(1, 1), - cirq.GridQubit(1, 0), - cirq.GridQubit(1, 2), - cirq.GridQubit(2, 1), - ] - qubits_3 = cirq.LineQubit.range(8) +def test_postselection_symmetry_validation_and_logic() -> None: + """Test validation and QWC logic for post-selection symmetries.""" + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.FrozenCircuit(cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1))) - circuit_1 = cirq.FrozenCircuit(_create_ghz(3, qubits_1)) - circuit_2 = cirq.FrozenCircuit(_create_ghz(5, qubits_2)) - circuit_3 = cirq.FrozenCircuit(_create_ghz(8, qubits_3)) + # Target Pauli String to measure: Z0 * Z1 + target_paulis: list[list[cirq.PauliString]] = [[cirq.PauliString(cirq.Z(q0) * cirq.Z(q1))]] - circuits_to_pauli: dict[cirq.FrozenCircuit, list[list[cirq.PauliString]]] = {} - circuits_to_pauli[circuit_1] = [ - _generate_qwc_paulis( - _generate_random_pauli_string(qubits_1, enable_coeff=True, allow_pauli_i=False), 6 - ) - for _ in range(3) - ] - circuits_to_pauli[circuit_2] = [_generate_random_pauli_string(qubits_2, True) for _ in range(3)] - circuits_to_pauli[circuit_3] = [_generate_random_pauli_string(qubits_3, True) for _ in range(3)] + sampler = cirq.Simulator() + rng = np.random.default_rng() - with pytest.raises( - TypeError, - match="Expected all elements to be sequences of ops.PauliString, " - "but found .", - ): - measure_pauli_strings( - circuits_to_pauli, cirq.Simulator(), 300, 300, 300, np.random.default_rng() - ) + # Test Valid PauliSum Symmetry + # Z0 commutes with Z0*Z1. Z1 commutes with Z0*Z1. + valid_sum_sym = cirq.PauliSum.from_pauli_strings( + [cirq.PauliString(cirq.Z(q0)), cirq.PauliString(cirq.Z(q1))] + ) + valid_pauli_sym: cirq.PauliString = cirq.PauliString(cirq.Z(q0)) + params_valid_sum = CircuitToPauliStringsParameters( + circuit=circuit, + pauli_strings=target_paulis, + postselection_symmetries=[(valid_sum_sym, 1), (valid_pauli_sym, 1)], + ) + # Should not raise any error + measure_pauli_strings([params_valid_sum], sampler, 10, 10, 0, rng) + + # Test PauliSum with Non-QWC Terms + # X0 and Z0 do not commute. This is an invalid PauliSum *structure* for this context. + invalid_structure_sum = cirq.PauliSum.from_pauli_strings( + [cirq.PauliString(cirq.X(q0)), cirq.PauliString(cirq.Z(q0))] + ) + params_bad_sum_structure = CircuitToPauliStringsParameters( + circuit=circuit, + pauli_strings=target_paulis, + postselection_symmetries=[(invalid_structure_sum, 1)], + ) + with pytest.raises(ValueError, match="Terms are not Qubit-Wise Commuting"): + measure_pauli_strings([params_bad_sum_structure], sampler, 10, 10, 0, rng) + + # Test Invalid Symmetry Type + params_bad_type = CircuitToPauliStringsParameters( + circuit=circuit, + pauli_strings=target_paulis, + postselection_symmetries=[("NotASymmetry", 1)], # type: ignore + ) + with pytest.raises(TypeError, match="must be cirq.PauliString or cirq.PauliSum"): + measure_pauli_strings([params_bad_type], sampler, 10, 10, 0, rng) + + # Test PauliSum NOT Commuting with Target + # X0 does not commute with Z0*Z1. + non_commuting_sum = cirq.PauliSum.from_pauli_strings([cirq.PauliString(cirq.X(q0))]) + params_non_commute = CircuitToPauliStringsParameters( + circuit=circuit, + pauli_strings=target_paulis, + postselection_symmetries=[(non_commuting_sum, 1)], + ) + with pytest.raises(ValueError, match="not commuting with all Pauli"): + measure_pauli_strings([params_non_commute], sampler, 10, 10, 0, rng)