diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 15152fc2c6e..8e5724aca64 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1767,6 +1767,7 @@ class Circuit(AbstractCircuit): * batch_remove * batch_insert_into * insert_at_frontier + * reverse Circuits can also be iterated over, @@ -2525,6 +2526,16 @@ def clear_operations_touching( self._moments[k] = self._moments[k].without_operations_touching(qubits) self._mutated() + def reverse(self) -> None: + """Reverses the moments in the circuit, and the operations in the moments.""" + # Work on a copy in case validation fails halfway through. + copy = self.copy() + backwards = [] + for moment in copy[::-1]: + backwards.append(Moment(reversed(moment.operations))) + self._moments = backwards + self._mutated() + @property def moments(self) -> Sequence[cirq.Moment]: return self._moments diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index cd8f7b00c70..3c60c06c074 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -1868,6 +1868,179 @@ def test_clear_operations_touching() -> None: ) +def test_reverse_empty_circuit(): + circuit = cirq.Circuit() + circuit.reverse() + assert len(circuit) == 0 + assert circuit == cirq.Circuit() + + +def test_reverse_single_moment_single_operation(): + q = cirq.GridQubit(0, 0) + circuit = cirq.Circuit(cirq.X(q)) + original_str = str(circuit) + + circuit.reverse() + + assert str(circuit) == original_str + assert len(circuit) == 1 + + +def test_reverse_single_moment_multiple_operations(): + """Test reversing a circuit with one moment and multiple operations.""" + q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2) + original_ops = [cirq.X(q0), cirq.Y(q1), cirq.Z(q2)] + circuit = cirq.Circuit(cirq.Moment(original_ops)) + + circuit.reverse() + + # Moment order unchanged (only one moment), but operations reversed + assert len(circuit) == 1 + reversed_ops = list(circuit[0]) + assert reversed_ops == list(reversed(original_ops)) + + +def test_reverse_multiple_moments_single_operations(): + """Test reversing a circuit with multiple moments, each with single operations.""" + q = cirq.GridQubit(0, 0) + circuit = cirq.Circuit([cirq.Moment([cirq.X(q)]), cirq.Moment([cirq.Y(q)]), cirq.Moment([cirq.Z(q)])]) + + original_moments = [str(moment) for moment in circuit] + circuit.reverse() + + # Moments should be reversed + assert len(circuit) == 3 + reversed_moments = [str(moment) for moment in circuit] + assert reversed_moments == list(reversed(original_moments)) + +def test_reverse_multiple_moments_multiple_operations(): + """Test reversing a circuit with multiple moments and multiple operations.""" + q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1) + circuit = cirq.Circuit( + [ + cirq.Moment([cirq.X(q0), cirq.Y(q1)]), + cirq.Moment([cirq.Z(q0), cirq.H(q1)]), + cirq.Moment([cirq.S(q0), cirq.T(q1)]) + ] + ) + + # Store original structure + original_structure = [] + for moment in circuit: + original_structure.append(list(moment.operations)) + + circuit.reverse() + + # Check that moments are reversed and operations within each moment are reversed + assert len(circuit) == 3 + + # First moment should be the reversed last moment + expected_first = list(reversed(original_structure[2])) + actual_first = list(circuit[0]) + assert actual_first == expected_first + + # Second moment should be the reversed middle moment + expected_second = list(reversed(original_structure[1])) + actual_second = list(circuit[1]) + assert actual_second == expected_second + + # Third moment should be the reversed first moment + expected_third = list(reversed(original_structure[0])) + actual_third = list(circuit[2]) + assert actual_third == expected_third + + +def test_reverse_twice_returns_original(): + """Test that reversing twice returns the original circuit.""" + q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1) + original_circuit = cirq.Circuit([ + cirq.Moment([cirq.X(q0), cirq.Y(q1)]), + cirq.Moment([cirq.Z(q0)]), + cirq.Moment([cirq.H(q0), cirq.S(q1)]) + ] + ) + + # Make a copy to compare against + expected = original_circuit.copy() + + # Reverse twice + original_circuit.reverse() + original_circuit.reverse() + + # Should be back to original + assert original_circuit == expected + + +def test_reverse_with_measurements(): + """Test reversing a circuit with measurement operations.""" + q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1) + circuit = cirq.Circuit( + [ + cirq.Moment([cirq.X(q0), cirq.Y(q1)]), + cirq.Moment([cirq.measure(q0, key='a'), cirq.measure(q1, key='b')]) + ] + ) + + original_structure = [] + for moment in circuit: + original_structure.append(list(moment.operations)) + + circuit.reverse() + + # Check structure is properly reversed + assert len(circuit) == 2 + + # First moment should be reversed measurements + actual_first = list(circuit[0]) + assert len(actual_first) == 2 + assert all(isinstance(op.gate, cirq.MeasurementGate) for op in actual_first) + + # Second moment should be reversed X, Y gates + actual_second = list(circuit[1]) + assert len(actual_second) == 2 + + +def test_reverse_with_two_qubit_gates(): + """Test reversing a circuit with two-qubit gates.""" + q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2) + circuit = cirq.Circuit( + [ + cirq.Moment([cirq.CNOT(q0, q1), cirq.X(q2)]), + cirq.Moment([cirq.CZ(q1, q2)]), + cirq.Moment([cirq.SWAP(q0, q2), cirq.Y(q1)]) + ] + ) + + original_structure = [] + for moment in circuit: + original_structure.append(list(moment.operations)) + + circuit.reverse() + + # Verify the structure is correctly reversed + assert len(circuit) == 3 + + # Check that two-qubit gates are preserved correctly + for i, moment in enumerate(circuit): + expected_ops = list(reversed(original_structure[2 - i])) + actual_ops = list(moment.operations) + assert actual_ops == expected_ops + + +def test_reverse_modifies_original_circuit(): + """Test that reverse() modifies the original circuit in-place.""" + q = cirq.GridQubit(0, 0) + circuit = cirq.Circuit([cirq.Moment([cirq.X(q)]), cirq.Moment([cirq.Y(q)])]) + + original_id = id(circuit) + circuit.reverse() + + # Should be the same object + assert id(circuit) == original_id + + # But content should be different + assert str(circuit[0]) != "X(q(0, 0))" # First moment is now Y + @pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit]) def test_all_qubits(circuit_cls) -> None: a = cirq.NamedQubit('a') diff --git a/cirq-core/cirq/transformers/stratify.py b/cirq-core/cirq/transformers/stratify.py index 23d40702d7b..e854dc207de 100644 --- a/cirq-core/cirq/transformers/stratify.py +++ b/cirq-core/cirq/transformers/stratify.py @@ -19,6 +19,8 @@ import itertools from typing import Callable, Iterable, Sequence, TYPE_CHECKING, Union +import copy + from cirq import _import, circuits, ops, protocols from cirq.transformers import transformer_api @@ -69,7 +71,8 @@ def stratified_circuit( # Try the algorithm with each permutation of the classifiers. smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1 shortest_stratified_circuit = circuits.Circuit() - reversed_circuit = circuit[::-1] + reversed_circuit = copy.deepcopy(circuit) + reversed_circuit.reverse() for ordered_classifiers in itertools.permutations(classifiers): solution = _stratify_circuit( circuit, @@ -87,7 +90,8 @@ def stratified_circuit( reversed_circuit, classifiers=ordered_classifiers, context=context or transformer_api.TransformerContext(), - )[::-1] + ) + solution.reverse() if len(solution) < smallest_depth: shortest_stratified_circuit = solution smallest_depth = len(solution)