Skip to content

Commit 018e541

Browse files
committed
Add analytic diff on PhasedXPowGate
1 parent 2ba9d73 commit 018e541

File tree

4 files changed

+244
-131
lines changed

4 files changed

+244
-131
lines changed

tensorflow_quantum/core/ops/math_ops/inner_product_hessian_test.py

Lines changed: 32 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
cirq.FSimGate,
4848
]
4949

50-
_ATOL_FOR_COMPLEX_GATE = 1e-1
50+
_ATOL_FOR_COMPLEX_GATE = 1e-2
5151
_COMPLEX_GATES = [
5252
cirq.PhasedXPowGate,
5353
]
@@ -72,11 +72,8 @@ def get_gate(gate, symbol_names, qubits):
7272

7373
def get_shifted_resolved_circuit(circuit, name_j, name_k, dx_j, dx_k, resolver):
7474
new_resolver = copy.deepcopy(resolver)
75-
if name_j == name_k:
76-
new_resolver.param_dict[name_j] += (dx_j + dx_k)
77-
else:
78-
new_resolver.param_dict[name_j] += dx_j
79-
new_resolver.param_dict[name_k] += dx_k
75+
new_resolver.param_dict[name_j] += dx_j
76+
new_resolver.param_dict[name_k] += dx_k
8077
return cirq.resolve_parameters(circuit, new_resolver)
8178

8279

@@ -91,8 +88,7 @@ def get_finite_difference_hessian(circuit, name_j, name_k, resolver):
9188
final_circuit_pm = get_shifted_resolved_circuit(
9289
circuit, name_j, name_k, dx, -dx, resolver)
9390
final_circuit_mm = get_shifted_resolved_circuit(
94-
circuit, name_j, name_k, -dx, -dx,
95-
resolver)
91+
circuit, name_j, name_k, -dx, -dx, resolver)
9692
final_wf_pp = inv_square_two_dx * cirq.final_state_vector(final_circuit_pp)
9793
final_wf_mp = inv_square_two_dx * cirq.final_state_vector(final_circuit_mp)
9894
final_wf_pm = inv_square_two_dx * cirq.final_state_vector(final_circuit_pm)
@@ -312,21 +308,21 @@ class InnerProductAdjHessianTest(tf.test.TestCase, parameterized.TestCase):
312308
'batch_size': 1,
313309
'inner_dim_size': 5
314310
},
315-
# {
316-
# 'n_qubits': 5,
317-
# 'batch_size': 10,
318-
# 'inner_dim_size': 1
319-
# },
320-
# {
321-
# 'n_qubits': 10,
322-
# 'batch_size': 10,
323-
# 'inner_dim_size': 2
324-
# },
325-
# {
326-
# 'n_qubits': 5,
327-
# 'batch_size': 10,
328-
# 'inner_dim_size': 5
329-
# },
311+
{
312+
'n_qubits': 5,
313+
'batch_size': 10,
314+
'inner_dim_size': 1
315+
},
316+
{
317+
'n_qubits': 10,
318+
'batch_size': 10,
319+
'inner_dim_size': 2
320+
},
321+
{
322+
'n_qubits': 5,
323+
'batch_size': 10,
324+
'inner_dim_size': 5
325+
},
330326
])
331327
def correctness_with_symbols(self, n_qubits, batch_size,
332328
inner_dim_size):
@@ -336,7 +332,7 @@ def correctness_with_symbols(self, n_qubits, batch_size,
336332
qubits = cirq.GridQubit.rect(1, n_qubits)
337333
circuit_batch, resolver_batch = \
338334
util.random_symbol_circuit_resolver_batch(
339-
qubits, symbol_names, batch_size, n_moments=2)
335+
qubits, symbol_names, batch_size)
340336
print(circuit_batch)
341337

342338
other_batch = [
@@ -494,33 +490,15 @@ class InnerProductHessianOnGates(tf.test.TestCase, parameterized.TestCase):
494490

495491
@parameterized.parameters([
496492
{
497-
'gate': [cirq.XPowGate],
498-
'symbol_names': ['alpha','beta','gamma']# names
499-
}]) # for gate in _ONE_EIGEN_GATES + _TWO_EIGEN_GATES for names in _SYMBOL_NAMES])
493+
'gate': gate,
494+
'symbol_names': names
495+
} for gate in _ONE_EIGEN_GATES + _TWO_EIGEN_GATES
496+
for names in _SYMBOL_NAMES])
500497
def test_correctness_one_qubit_gate_with_symbols(self, gate, symbol_names):
501498
"""Tests that inner_product works with symbols."""
502499
n_params = len(symbol_names)
503-
qubits = cirq.GridQubit.rect(1, 5) # 2 if gate in _TWO_EIGEN_GATES else 1)
504-
# circuit_batch = [cirq.Circuit(get_gate(gate, symbol_names, qubits))]
505-
circuit_batch = [cirq.Circuit([
506-
cirq.Moment(
507-
(cirq.H**sympy.Mul(sympy.Float('0.96078327350981163', precision=53), sympy.Symbol('gamma'))).on(cirq.GridQubit(0, 0)),
508-
(cirq.Y**sympy.Mul(sympy.Float('0.73193316007366105', precision=53), sympy.Symbol('alpha'))).on(cirq.GridQubit(0, 3)),
509-
),
510-
cirq.Moment(
511-
cirq.Y(cirq.GridQubit(0, 1)),
512-
cirq.FSimGate(theta=0.123, phi=0.456).on(cirq.GridQubit(0, 4), cirq.GridQubit(0, 3)),
513-
cirq.PhasedXPowGate(phase_exponent=0.123).on(cirq.GridQubit(0, 0)),
514-
),
515-
cirq.Moment(
516-
cirq.Y(cirq.GridQubit(0, 4)),
517-
cirq.FSimGate(theta=0.123, phi=0.456).on(cirq.GridQubit(0, 2), cirq.GridQubit(0, 3)),
518-
),
519-
cirq.Moment(
520-
(cirq.H**sympy.Symbol('beta')).on(cirq.GridQubit(0, 0)),
521-
),
522-
])]
523-
500+
qubits = cirq.GridQubit.rect(1, 2 if gate in _TWO_EIGEN_GATES else 1)
501+
circuit_batch = [cirq.Circuit(get_gate(gate, symbol_names, qubits))]
524502
resolver_batch = [cirq.ParamResolver({name: 0.123 for name in symbol_names})]
525503

526504
symbol_values_array = np.array(
@@ -560,15 +538,12 @@ def test_correctness_one_qubit_gate_with_symbols(self, gate, symbol_names):
560538
else:
561539
weighted_internal_wf += internal_wf
562540
for j, name_j in enumerate(symbol_names):
563-
out_arr[i][j][j] = inner_product_op._inner_product_grad(
564-
programs, symbol_names_tensor, symbol_values, other_programs,
565-
other_programs_coeffs)[i][j]
566-
# for k, name_k in enumerate(symbol_names):
567-
# final_wf_grad = get_finite_difference_hessian(
568-
# circuit_batch[i], name_j, name_k, resolver)
569-
# out_arr[i][j][k] += (
570-
# programs_coeffs[i] *
571-
# np.vdot(final_wf_grad, weighted_internal_wf))
541+
for k, name_k in enumerate(symbol_names):
542+
final_wf_grad = get_finite_difference_hessian(
543+
circuit_batch[i], name_j, name_k, resolver)
544+
out_arr[i][j][k] += (
545+
programs_coeffs[i] *
546+
np.vdot(final_wf_grad, weighted_internal_wf))
572547

573548
# Elapsed time should be less than 5% of cirq version.
574549
# (at least 20x speedup)

tensorflow_quantum/core/ops/math_ops/tfq_inner_product_hessian.cc

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -312,21 +312,6 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
312312
std::cout << ">>>>>>... " << k << "th gradient gate is applied" << std::endl;
313313
qsim::ApplyGate(sim, hessian_gates[i][l - 1].grad_gates[k], scratch2);
314314

315-
auto ptr = scratch2.get();
316-
auto ptr_size = 2 << scratch2.num_qubits();
317-
std::cout << "Statevector" << std::endl;
318-
for (int i = 0; i < ptr_size; i++) {
319-
std::cout << ptr[i] << ",";
320-
}
321-
std::cout << std::endl;
322-
323-
ptr = scratch.get();
324-
ptr_size = 2 << scratch.num_qubits();
325-
std::cout << "Other Statevector" << std::endl;
326-
for (int i = 0; i < ptr_size; i++) {
327-
std::cout << ptr[i] << ",";
328-
}
329-
std::cout << std::endl;
330315
// don't need not-found check since this is done upstream already.
331316
auto symbol = hessian_gates[i][l - 1].params[k];
332317
std::cout << ">>>>>>... " << k << "th symbol = " << symbol << std::endl;
@@ -389,7 +374,7 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
389374
other_fused_circuits[i], sim, ss,
390375
scratch2, scratch);
391376
// now sv is |psi>
392-
// scratch contains sum_j other_programs_coeffs[i][j]*|phi[i][j]>
377+
// other_sv contains sum_j other_programs_coeffs[i][j]*|phi[i][j]>
393378
// Start adjoint differentiation on two gates
394379
// m is the index for the first gate
395380
std::cout << ">>> Start two gates hessian" << std::endl;
@@ -414,30 +399,28 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
414399
mask_m |= uint64_t{1} << control_loc;
415400
cbits_m |= ((cur_gate_m.cmask >> k) & 1) << control_loc;
416401
}
402+
403+
ss.Copy(scratch, scratch4);
404+
ss.Copy(sv, scratch2);
417405
for (std::vector<QsimGate>::size_type p = 0;
418406
p < gradient_gates[i][m - 1].grad_gates.size(); p++) {
419407
// Copy sv onto scratch2 in anticipation of the first non-unitary
420408
// "gradient gate".
421-
ss.Copy(sv, scratch2);
422409
if (!cur_gate_m.controlled_by.empty()) {
423410
// Gradient of controlled gates puts zeros on diagonal which is
424411
// the same as collapsing the state and then applying the
425412
// non-controlled version of the gradient gate.
426-
ss.BulkSetAmpl(scratch2, mask_m, cbits_m, 0, 0, true);
413+
ss.BulkSetAmpl(scratch4, mask_m, cbits_m, 0, 0, true);
427414
}
428415
std::cout << ">>>>>>(1)... p=" << p << "th gradient gate is applied" << std::endl;
429-
qsim::ApplyGate(sim, gradient_gates[i][m - 1].grad_gates[p],
430-
scratch2);
416+
qsim::ApplyGateDagger(sim, gradient_gates[i][m - 1].grad_gates[p],
417+
scratch4);
431418

432419
// don't need not-found check since this is done upstream already.
433420
const auto it = maps[i].find(gradient_gates[i][m - 1].params[p]);
434421
std::cout << ">>>>>>(1)... p=" << p << "th symbol = " << gradient_gates[i][m - 1].params[p] << std::endl;
435422
const int loc_m = it->second.first;
436423

437-
// scratch2 is now (d/dsymbol[p])|psi>
438-
// Copy scratch onto scratch4.
439-
ss.Copy(scratch, scratch4);
440-
// ApplyGateDagger(sim, cur_gate_m, scratch4);
441424
// n is the index for the second gate
442425
for (int n = m - 1; n >= 0; n--) {
443426
std::cout << ">>>>>>---(2) " << n << "th partial fused circuit is applied" << std::endl;
@@ -455,6 +438,7 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
455438
}
456439

457440
// Hit a parameterized gate.
441+
std::cout << "n-th gate index = " << gradient_gates[i][n - 1].index << std::endl;
458442
auto cur_gate_n =
459443
qsim_circuits[i].gates[gradient_gates[i][n - 1].index];
460444
ApplyGateDagger(sim, cur_gate_n, scratch2);
@@ -485,6 +469,21 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
485469
qsim::ApplyGate(sim, gradient_gates[i][n - 1].grad_gates[q],
486470
scratch3);
487471

472+
auto ptr = scratch3.get();
473+
auto ptr_size = 2 << scratch3.num_qubits();
474+
std::cout << "Statevector" << std::endl;
475+
for (int i = 0; i < ptr_size; i++) {
476+
std::cout << ptr[i] << ",";
477+
}
478+
std::cout << std::endl;
479+
480+
ptr = scratch4.get();
481+
ptr_size = 2 << scratch4.num_qubits();
482+
std::cout << "Other Statevector" << std::endl;
483+
for (int i = 0; i < ptr_size; i++) {
484+
std::cout << ptr[i] << ",";
485+
}
486+
std::cout << std::endl;
488487
// don't need not-found check since this is done upstream already.
489488
const auto it = maps[i].find(gradient_gates[i][n - 1].params[q]);
490489
std::cout << ">>>>>>---(2)... q=" << q << "th symbol = " << gradient_gates[i][n - 1].params[q] << std::endl;
@@ -543,7 +542,6 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
543542
Simulator sim = Simulator(tfq_for);
544543
StateSpace ss = StateSpace(tfq_for);
545544
auto sv = ss.Create(largest_nq);
546-
auto sv_adj = ss.Create(largest_nq);
547545
auto scratch = ss.Create(largest_nq);
548546
auto scratch2 = ss.Create(largest_nq);
549547
auto scratch3 = ss.Create(largest_nq);
@@ -560,7 +558,6 @@ class TfqInnerProductHessianOp : public tensorflow::OpKernel {
560558
if (nq > largest_nq) {
561559
largest_nq = nq;
562560
sv = ss.Create(largest_nq);
563-
sv_adj = ss.Create(largest_nq);
564561
scratch = ss.Create(largest_nq);
565562
scratch2 = ss.Create(largest_nq);
566563
}

tensorflow_quantum/core/src/adj_hessian_util.cc

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -389,21 +389,23 @@ void PopulateHessianPhasedXPhasedExponent(const std::string& symbol,
389389
float gs, GradientOfGate* grad) {
390390
grad->params.push_back(symbol);
391391
grad->index = location;
392-
auto left = qsim::Cirq::PhasedXPowGate<float>::Create(
393-
0, qid, (pexp + _HESS_EPS) * pexp_s, exp * exp_s, gs);
394-
auto center = qsim::Cirq::PhasedXPowGate<float>::Create(0, qid, pexp * pexp_s,
395-
exp * exp_s, gs);
396-
auto right = qsim::Cirq::PhasedXPowGate<float>::Create(
397-
0, qid, (pexp - _HESS_EPS) * pexp_s, exp * exp_s, gs);
398-
// Due to precision issue, multiply weights first.
399-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left.matrix);
400-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right.matrix);
401-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, center.matrix);
402-
Matrix2Add(right.matrix,
403-
left.matrix); // left's entries have right added.
404-
qsim::MatrixScalarMultiply(2.0, center.matrix);
405-
Matrix2Diff(center.matrix,
406-
left.matrix); // left's entries have center subtracted.
392+
// auto left = qsim::Cirq::PhasedXPowGate<float>::Create(
393+
// 0, qid, (pexp + _HESS_EPS) * pexp_s, exp * exp_s, gs);
394+
// auto center = qsim::Cirq::PhasedXPowGate<float>::Create(0, qid, pexp * pexp_s,
395+
// exp * exp_s, gs);
396+
// auto right = qsim::Cirq::PhasedXPowGate<float>::Create(
397+
// 0, qid, (pexp - _HESS_EPS) * pexp_s, exp * exp_s, gs);
398+
// // Due to precision issue, multiply weights first.
399+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left.matrix);
400+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right.matrix);
401+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, center.matrix);
402+
// Matrix2Add(right.matrix,
403+
// left.matrix); // left's entries have right added.
404+
// qsim::MatrixScalarMultiply(2.0, center.matrix);
405+
// Matrix2Diff(center.matrix,
406+
// left.matrix); // left's entries have center subtracted.
407+
auto left = D2PhasedExponentPhasedXPowGate<float>::Create(
408+
0, qid, pexp, pexp_s, exp*exp_s, gs);
407409
grad->grad_gates.push_back(left);
408410
}
409411

@@ -414,21 +416,23 @@ void PopulateHessianPhasedXExponent(const std::string& symbol,
414416
GradientOfGate* grad) {
415417
grad->params.push_back(symbol);
416418
grad->index = location;
417-
auto left = qsim::Cirq::PhasedXPowGate<float>::Create(
418-
0, qid, pexp * pexp_s, (exp + _HESS_EPS) * exp_s, gs);
419-
auto center = qsim::Cirq::PhasedXPowGate<float>::Create(0, qid, pexp * pexp_s,
420-
exp * exp_s, gs);
421-
auto right = qsim::Cirq::PhasedXPowGate<float>::Create(
422-
0, qid, pexp * pexp_s, (exp - _HESS_EPS) * exp_s, gs);
423-
// Due to precision issue, multiply weights first.
424-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left.matrix);
425-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right.matrix);
426-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, center.matrix);
427-
Matrix2Add(right.matrix,
428-
left.matrix); // left's entries have right added.
429-
qsim::MatrixScalarMultiply(2.0, center.matrix);
430-
Matrix2Diff(center.matrix,
431-
left.matrix); // left's entries have center subtracted.
419+
// auto left = qsim::Cirq::PhasedXPowGate<float>::Create(
420+
// 0, qid, pexp * pexp_s, (exp + _HESS_EPS) * exp_s, gs);
421+
// auto center = qsim::Cirq::PhasedXPowGate<float>::Create(0, qid, pexp * pexp_s,
422+
// exp * exp_s, gs);
423+
// auto right = qsim::Cirq::PhasedXPowGate<float>::Create(
424+
// 0, qid, pexp * pexp_s, (exp - _HESS_EPS) * exp_s, gs);
425+
// // Due to precision issue, multiply weights first.
426+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left.matrix);
427+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right.matrix);
428+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, center.matrix);
429+
// Matrix2Add(right.matrix,
430+
// left.matrix); // left's entries have right added.
431+
// qsim::MatrixScalarMultiply(2.0, center.matrix);
432+
// Matrix2Diff(center.matrix,
433+
// left.matrix); // left's entries have center subtracted.
434+
auto left = D2ExponentPhasedXPowGate<float>::Create(
435+
0, qid, pexp * pexp_s, exp, exp_s, gs);
432436
grad->grad_gates.push_back(left);
433437
}
434438

@@ -437,24 +441,26 @@ void PopulateCrossTermPhasedXPhasedExponentExponent(
437441
float exp, float exp_s, float gs, GradientOfGate* grad) {
438442
grad->params.push_back(kUsePrevTwoSymbols);
439443
grad->index = location;
440-
auto left = qsim::Cirq::PhasedXPowGate<float>::Create(
441-
0, qid, (pexp + _GRAD_EPS) * pexp_s, (exp + _GRAD_EPS) * exp_s, gs);
442-
auto left_center = qsim::Cirq::PhasedXPowGate<float>::Create(
443-
0, qid, (pexp + _GRAD_EPS) * pexp_s, (exp - _GRAD_EPS) * exp_s, gs);
444-
auto right_center = qsim::Cirq::PhasedXPowGate<float>::Create(
445-
0, qid, (pexp - _GRAD_EPS) * pexp_s, (exp + _GRAD_EPS) * exp_s, gs);
446-
auto right = qsim::Cirq::PhasedXPowGate<float>::Create(
447-
0, qid, (pexp - _GRAD_EPS) * pexp_s, (exp - _GRAD_EPS) * exp_s, gs);
448-
// Due to precision issue, multiply weights first.
449-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left.matrix);
450-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right.matrix);
451-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left_center.matrix);
452-
qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right_center.matrix);
453-
Matrix2Add(right.matrix,
454-
left.matrix); // left's entries have right added.
455-
Matrix2Add(right_center.matrix, left_center.matrix);
456-
Matrix2Diff(left_center.matrix,
457-
left.matrix); // left's entries have left_center subtracted.
444+
// auto left = qsim::Cirq::PhasedXPowGate<float>::Create(
445+
// 0, qid, (pexp + _GRAD_EPS) * pexp_s, (exp + _GRAD_EPS) * exp_s, gs);
446+
// auto left_center = qsim::Cirq::PhasedXPowGate<float>::Create(
447+
// 0, qid, (pexp + _GRAD_EPS) * pexp_s, (exp - _GRAD_EPS) * exp_s, gs);
448+
// auto right_center = qsim::Cirq::PhasedXPowGate<float>::Create(
449+
// 0, qid, (pexp - _GRAD_EPS) * pexp_s, (exp + _GRAD_EPS) * exp_s, gs);
450+
// auto right = qsim::Cirq::PhasedXPowGate<float>::Create(
451+
// 0, qid, (pexp - _GRAD_EPS) * pexp_s, (exp - _GRAD_EPS) * exp_s, gs);
452+
// // Due to precision issue, multiply weights first.
453+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left.matrix);
454+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right.matrix);
455+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, left_center.matrix);
456+
// qsim::MatrixScalarMultiply(_INVERSE_HESS_EPS_SQUARE, right_center.matrix);
457+
// Matrix2Add(right.matrix,
458+
// left.matrix); // left's entries have right added.
459+
// Matrix2Add(right_center.matrix, left_center.matrix);
460+
// Matrix2Diff(left_center.matrix,
461+
// left.matrix); // left's entries have left_center subtracted.
462+
auto left = DPhasedExponentDExponentPhasedXPowGate<float>::Create(
463+
0, qid, pexp, pexp_s, exp, exp_s, gs);
458464
grad->grad_gates.push_back(left);
459465
}
460466

0 commit comments

Comments
 (0)