diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index eb3a8de36e8..1f26fcab78c 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -356,20 +356,21 @@ def test_add_op_tree(circuit_cls) -> None: @pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit]) -def test_radd_op_tree(circuit_cls) -> None: +@pytest.mark.parametrize('gate', [cirq.X, cirq.H]) +def test_radd_op_tree(circuit_cls, gate) -> None: a = cirq.NamedQubit('a') b = cirq.NamedQubit('b') c = circuit_cls() - assert [cirq.X(a), cirq.Y(b)] + c == circuit_cls([cirq.Moment([cirq.X(a), cirq.Y(b)])]) + assert [gate(a), cirq.Y(b)] + c == circuit_cls([cirq.Moment([gate(a), cirq.Y(b)])]) - assert cirq.X(a) + c == circuit_cls(cirq.X(a)) - assert [cirq.X(a)] + c == circuit_cls(cirq.X(a)) - assert [[[cirq.X(a)], []]] + c == circuit_cls(cirq.X(a)) - assert (cirq.X(a),) + c == circuit_cls(cirq.X(a)) - assert (cirq.X(a) for _ in range(1)) + c == circuit_cls(cirq.X(a)) + assert gate(a) + c == circuit_cls(gate(a)) + assert [gate(a)] + c == circuit_cls(gate(a)) + assert [[[gate(a)], []]] + c == circuit_cls(gate(a)) + assert (gate(a),) + c == circuit_cls(gate(a)) + assert (gate(a) for _ in range(1)) + c == circuit_cls(gate(a)) with pytest.raises(AttributeError): - _ = cirq.X + c + _ = gate + c with pytest.raises(TypeError): _ = 0 + c @@ -380,9 +381,9 @@ def test_radd_op_tree(circuit_cls) -> None: else: d = cirq.Circuit() d.append(cirq.Y(b)) - assert [cirq.X(a)] + d == circuit_cls([cirq.Moment([cirq.X(a)]), cirq.Moment([cirq.Y(b)])]) - assert cirq.Moment([cirq.X(a)]) + d == circuit_cls( - [cirq.Moment([cirq.X(a)]), cirq.Moment([cirq.Y(b)])] + assert [gate(a)] + d == circuit_cls([cirq.Moment([gate(a)]), cirq.Moment([cirq.Y(b)])]) + assert cirq.Moment([gate(a)]) + d == circuit_cls( + [cirq.Moment([gate(a)]), cirq.Moment([cirq.Y(b)])] ) diff --git a/cirq-core/cirq/ops/gate_operation.py b/cirq-core/cirq/ops/gate_operation.py index 9c8f2ed8657..33929591a2f 100644 --- a/cirq-core/cirq/ops/gate_operation.py +++ b/cirq-core/cirq/ops/gate_operation.py @@ -16,6 +16,7 @@ from __future__ import annotations +import numbers import re import warnings from collections.abc import Collection, Mapping, Sequence, Set @@ -323,12 +324,16 @@ def __rmul__(self, other: Any) -> Any: return self.gate._rmul_with_qubits(self._qubits, other) def __add__(self, other): + if not isinstance(other, (ops.Operation, numbers.Number)): + return NotImplemented return 1 * self + other def __radd__(self, other): return other + 1 * self def __sub__(self, other): + if not isinstance(other, (ops.Operation, numbers.Number)): + return NotImplemented return 1 * self - other def __rsub__(self, other): diff --git a/cirq-core/cirq/ops/linear_combinations.py b/cirq-core/cirq/ops/linear_combinations.py index 275d8e4782f..eae0def1d88 100644 --- a/cirq-core/cirq/ops/linear_combinations.py +++ b/cirq-core/cirq/ops/linear_combinations.py @@ -743,6 +743,8 @@ def __len__(self) -> int: return len(self._linear_dict) def __iadd__(self, other): + if isinstance(other, raw_types.Operation): + other = pauli_string._try_interpret_as_pauli_string(other) if isinstance(other, numbers.Complex): other = PauliSum.from_pauli_strings([PauliString(coefficient=other)]) elif isinstance(other, PauliString): @@ -767,6 +769,8 @@ def __rsub__(self, other): return -self.__sub__(other) def __isub__(self, other): + if isinstance(other, raw_types.Operation): + other = pauli_string._try_interpret_as_pauli_string(other) if isinstance(other, numbers.Complex): other = PauliSum.from_pauli_strings([PauliString(coefficient=other)]) elif isinstance(other, PauliString): diff --git a/cirq-core/cirq/ops/linear_combinations_test.py b/cirq-core/cirq/ops/linear_combinations_test.py index da515ee3859..dae9fe4c9ab 100644 --- a/cirq-core/cirq/ops/linear_combinations_test.py +++ b/cirq-core/cirq/ops/linear_combinations_test.py @@ -975,10 +975,16 @@ def test_paulisum_validation() -> None: ps += cirq.I(cirq.LineQubit(0)) assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(1)})) + ps = cirq.I(cirq.LineQubit(0)) + cirq.PauliSum() + assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(1)})) + ps = cirq.PauliSum() ps -= cirq.I(cirq.LineQubit(0)) assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(-1)})) + ps = cirq.I(cirq.LineQubit(0)) - cirq.PauliSum() + assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(1)})) + def test_add_number_paulisum() -> None: q = cirq.LineQubit.range(2) @@ -1008,6 +1014,22 @@ def test_add_number_paulistring() -> None: == cirq.PauliSum.from_pauli_strings([cirq.PauliString() * 2, cirq.PauliString({a: cirq.X})]) ) + assert ( + cirq.X(a) - 2 + == -2 + cirq.X(a) + == cirq.PauliSum.from_pauli_strings( + [cirq.PauliString() * -2, cirq.PauliString({a: cirq.X})] + ) + ) + + assert ( + 2 - cirq.X(a) + == 2 + -cirq.X(a) + == cirq.PauliSum.from_pauli_strings( + [cirq.PauliString() * 2, -cirq.PauliString({a: cirq.X})] + ) + ) + def test_pauli_sum_formatting() -> None: q = cirq.LineQubit.range(2)