Skip to content

Commit 96c2004

Browse files
daxfohlpavoljuhas
andauthored
[bugfix] Prevent GateOperation.add from superseding Circuit.radd (#7707)
Per issue #7706, the change in #7651 caused `Circuit.__radd__`, which allows `+` to prepend operations to circuits, e.g. `H(q) + Circuit(...) == Circuit(H(q), ...)`, to stop working. The root cause is that the change added a `GateOperation.__add__`, which thus supersedes `Circuit.__radd__` and changes the behavior of `+` in such cases. The fix for this is to modify the new `GateOperation.__add__` (and `GateOperation.__sub__`, for symmetry) so that it only applies in cases where interpreting the sum as a `PauliString` is reasonable. To do this, first, I updated `GateOperation.__add__` with a check, `if not isinstance(other, (Operation, Number)): return NotImplemented`, since ops and numbers are the only things that can be interpreted as a `PauliString`. This change prevents `GateOperation.__add__` from interfering with `Circuit.__radd__`, and constitutes the crux of the fix. That change in isolation breaks `GateOperation + PauliSum` because the latter isn't an `Operation`. The simplest fix for *this* would be to add `PauliSum` to the new isinstance check, e.g. `isinstance(other, (Operation, Number, PauliSum))`, However that approach adds `PauliSum` as a dependency of `GateOperation`, which seems unnatural. So instead, I added logic in `PauliSum.__add__` to convert ops to `PauliString` first. That approach doesn't create any new dependencies, and is also more robust because it works for *all* operations, not just GateOperations. (Note, an even simpler fix for the entire issue would have been, instead of allowlisting `(Operation, Number)` in `GateOperation.__add__`, to denylist `Circuit` instead, e.g. `if isinstance(other, Circuit): return NotImplemented`. That would have fixed the issue and not required any special change for `PauliSum`. But similarly, I chose against that route because of the unnatural dependency it would create, as well as the fact that that fix wouldn't help if any third-party classes used `radd` on operations in the same way `Circuit` does.) Finally, I updated the `Circuit.radd` unit test to include assertions that would have broken under the old code. The existing test used `X` gates everywhere, which didn't fail when prepending because of misc internal logic for Pauli gates. So I parameterized the test on `gate`, and added a non-Pauli gate `H` to the parameters, which will now fail if a similar regression is made in the future. I updated some tests in linear_combinations to test `__sub__` more thoroughly as well, for symmetry. Key takeaway: be careful when implementing `__add__` for an existing class, and limit the scope of `other` to which it can apply, as otherwise the change can break `__radd__` functionality in unrelated classes. Fixes #7706 --------- Co-authored-by: Pavol Juhas <pavol.juhas@gmail.com>
1 parent 57f17ce commit 96c2004

File tree

4 files changed

+43
-11
lines changed

4 files changed

+43
-11
lines changed

cirq-core/cirq/circuits/circuit_test.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -356,20 +356,21 @@ def test_add_op_tree(circuit_cls) -> None:
356356

357357

358358
@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
359-
def test_radd_op_tree(circuit_cls) -> None:
359+
@pytest.mark.parametrize('gate', [cirq.X, cirq.H])
360+
def test_radd_op_tree(circuit_cls, gate) -> None:
360361
a = cirq.NamedQubit('a')
361362
b = cirq.NamedQubit('b')
362363

363364
c = circuit_cls()
364-
assert [cirq.X(a), cirq.Y(b)] + c == circuit_cls([cirq.Moment([cirq.X(a), cirq.Y(b)])])
365+
assert [gate(a), cirq.Y(b)] + c == circuit_cls([cirq.Moment([gate(a), cirq.Y(b)])])
365366

366-
assert cirq.X(a) + c == circuit_cls(cirq.X(a))
367-
assert [cirq.X(a)] + c == circuit_cls(cirq.X(a))
368-
assert [[[cirq.X(a)], []]] + c == circuit_cls(cirq.X(a))
369-
assert (cirq.X(a),) + c == circuit_cls(cirq.X(a))
370-
assert (cirq.X(a) for _ in range(1)) + c == circuit_cls(cirq.X(a))
367+
assert gate(a) + c == circuit_cls(gate(a))
368+
assert [gate(a)] + c == circuit_cls(gate(a))
369+
assert [[[gate(a)], []]] + c == circuit_cls(gate(a))
370+
assert (gate(a),) + c == circuit_cls(gate(a))
371+
assert (gate(a) for _ in range(1)) + c == circuit_cls(gate(a))
371372
with pytest.raises(AttributeError):
372-
_ = cirq.X + c
373+
_ = gate + c
373374
with pytest.raises(TypeError):
374375
_ = 0 + c
375376

@@ -380,9 +381,9 @@ def test_radd_op_tree(circuit_cls) -> None:
380381
else:
381382
d = cirq.Circuit()
382383
d.append(cirq.Y(b))
383-
assert [cirq.X(a)] + d == circuit_cls([cirq.Moment([cirq.X(a)]), cirq.Moment([cirq.Y(b)])])
384-
assert cirq.Moment([cirq.X(a)]) + d == circuit_cls(
385-
[cirq.Moment([cirq.X(a)]), cirq.Moment([cirq.Y(b)])]
384+
assert [gate(a)] + d == circuit_cls([cirq.Moment([gate(a)]), cirq.Moment([cirq.Y(b)])])
385+
assert cirq.Moment([gate(a)]) + d == circuit_cls(
386+
[cirq.Moment([gate(a)]), cirq.Moment([cirq.Y(b)])]
386387
)
387388

388389

cirq-core/cirq/ops/gate_operation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import numbers
1920
import re
2021
import warnings
2122
from collections.abc import Collection, Mapping, Sequence, Set
@@ -323,12 +324,16 @@ def __rmul__(self, other: Any) -> Any:
323324
return self.gate._rmul_with_qubits(self._qubits, other)
324325

325326
def __add__(self, other):
327+
if not isinstance(other, (ops.Operation, numbers.Number)):
328+
return NotImplemented
326329
return 1 * self + other
327330

328331
def __radd__(self, other):
329332
return other + 1 * self
330333

331334
def __sub__(self, other):
335+
if not isinstance(other, (ops.Operation, numbers.Number)):
336+
return NotImplemented
332337
return 1 * self - other
333338

334339
def __rsub__(self, other):

cirq-core/cirq/ops/linear_combinations.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,8 @@ def __len__(self) -> int:
743743
return len(self._linear_dict)
744744

745745
def __iadd__(self, other):
746+
if isinstance(other, raw_types.Operation):
747+
other = pauli_string._try_interpret_as_pauli_string(other)
746748
if isinstance(other, numbers.Complex):
747749
other = PauliSum.from_pauli_strings([PauliString(coefficient=other)])
748750
elif isinstance(other, PauliString):
@@ -767,6 +769,8 @@ def __rsub__(self, other):
767769
return -self.__sub__(other)
768770

769771
def __isub__(self, other):
772+
if isinstance(other, raw_types.Operation):
773+
other = pauli_string._try_interpret_as_pauli_string(other)
770774
if isinstance(other, numbers.Complex):
771775
other = PauliSum.from_pauli_strings([PauliString(coefficient=other)])
772776
elif isinstance(other, PauliString):

cirq-core/cirq/ops/linear_combinations_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,10 +975,16 @@ def test_paulisum_validation() -> None:
975975
ps += cirq.I(cirq.LineQubit(0))
976976
assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(1)}))
977977

978+
ps = cirq.I(cirq.LineQubit(0)) + cirq.PauliSum()
979+
assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(1)}))
980+
978981
ps = cirq.PauliSum()
979982
ps -= cirq.I(cirq.LineQubit(0))
980983
assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(-1)}))
981984

985+
ps = cirq.I(cirq.LineQubit(0)) - cirq.PauliSum()
986+
assert ps == cirq.PauliSum(cirq.LinearDict({frozenset(): complex(1)}))
987+
982988

983989
def test_add_number_paulisum() -> None:
984990
q = cirq.LineQubit.range(2)
@@ -1008,6 +1014,22 @@ def test_add_number_paulistring() -> None:
10081014
== cirq.PauliSum.from_pauli_strings([cirq.PauliString() * 2, cirq.PauliString({a: cirq.X})])
10091015
)
10101016

1017+
assert (
1018+
cirq.X(a) - 2
1019+
== -2 + cirq.X(a)
1020+
== cirq.PauliSum.from_pauli_strings(
1021+
[cirq.PauliString() * -2, cirq.PauliString({a: cirq.X})]
1022+
)
1023+
)
1024+
1025+
assert (
1026+
2 - cirq.X(a)
1027+
== 2 + -cirq.X(a)
1028+
== cirq.PauliSum.from_pauli_strings(
1029+
[cirq.PauliString() * 2, -cirq.PauliString({a: cirq.X})]
1030+
)
1031+
)
1032+
10111033

10121034
def test_pauli_sum_formatting() -> None:
10131035
q = cirq.LineQubit.range(2)

0 commit comments

Comments
 (0)