Skip to content

Commit 2ba9d73

Browse files
committed
Fix adjoint hessian bug
1 parent ee12f8d commit 2ba9d73

File tree

13 files changed

+2310
-19
lines changed

13 files changed

+2310
-19
lines changed

tensorflow_quantum/core/ops/math_ops/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_binary(
1717
srcs = [
1818
"tfq_inner_product.cc",
1919
"tfq_inner_product_grad.cc",
20+
"tfq_inner_product_hessian.cc",
2021
],
2122
copts = select({
2223
":windows": [
@@ -62,6 +63,7 @@ cc_binary(
6263
# cirq cc proto
6364
"//tensorflow_quantum/core/ops:parse_context",
6465
"//tensorflow_quantum/core/ops:tfq_simulate_utils",
66+
"//tensorflow_quantum/core/src:adj_hessian_util",
6567
"//tensorflow_quantum/core/src:adj_util",
6668
"//tensorflow_quantum/core/src:circuit_parser_qsim",
6769
"//tensorflow_quantum/core/src:util_qsim",
@@ -100,3 +102,13 @@ py_test(
100102
"//tensorflow_quantum/python:util",
101103
],
102104
)
105+
106+
py_test(
107+
name = "inner_product_hessian_test",
108+
srcs = ["inner_product_hessian_test.py"],
109+
python_version = "PY3",
110+
deps = [
111+
":inner_product_op_py",
112+
"//tensorflow_quantum/python:util",
113+
],
114+
)

tensorflow_quantum/core/ops/math_ops/inner_product_grad_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class InnerProductAdjGradTest(tf.test.TestCase, parameterized.TestCase):
2727
"""Tests tfq_inner_product_grad."""
2828

2929
def test_inner_product_grad_inputs(self):
30-
"""Makes sure that inner_product_adj_grad fails on bad inputs."""
30+
"""Makes sure that inner_product_grad fails on bad inputs."""
3131
n_qubits = 5
3232
batch_size = 5
3333
n_other_programs = 3
@@ -232,7 +232,7 @@ def test_inner_product_grad_inputs(self):
232232
])
233233
def test_correctness_with_symbols(self, n_qubits, batch_size,
234234
inner_dim_size):
235-
"""Tests that inner_product works with symbols."""
235+
"""Tests that inner_product_grad works with symbols."""
236236
symbol_names = ['alpha', 'beta', 'gamma']
237237
n_params = len(symbol_names)
238238
qubits = cirq.GridQubit.rect(1, n_qubits)
@@ -242,7 +242,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
242242

243243
other_batch = [
244244
util.random_circuit_resolver_batch(qubits, inner_dim_size)[0]
245-
for i in range(batch_size)
245+
for _ in range(batch_size)
246246
]
247247

248248
symbol_values_array = np.array(
@@ -312,15 +312,15 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
312312
])
313313
def test_correctness_without_symbols(self, n_qubits, batch_size,
314314
inner_dim_size):
315-
"""Tests that inner_product_adj_grad works without symbols."""
315+
"""Tests that inner_product_grad works without symbols."""
316316
qubits = cirq.GridQubit.rect(1, n_qubits)
317317
circuit_batch, _ = \
318318
util.random_circuit_resolver_batch(
319319
qubits, batch_size)
320320

321321
other_batch = [
322322
util.random_circuit_resolver_batch(qubits, inner_dim_size)[0]
323-
for i in range(batch_size)
323+
for _ in range(batch_size)
324324
]
325325

326326
programs = util.convert_to_tensor(circuit_batch)

0 commit comments

Comments
 (0)