Skip to content

Commit 91c4d8a

Browse files
merge conflict.
2 parents 63522df + 6e67c0e commit 91c4d8a

File tree

5 files changed

+522
-14
lines changed

5 files changed

+522
-14
lines changed

tensorflow_quantum/core/serialize/serializer.py

Lines changed: 180 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _optional_control_promote(gate, qubits_message, values_message):
154154
return DelayedAssignmentGate(gate, qbs, vals)
155155

156156

157+
# Channels.
157158
def _asymmetric_depolarize_serializer():
158159
"""Make standard serializer for asymmetric depolarization channel."""
159160
args = [
@@ -231,9 +232,174 @@ def _depolarize_channel_deserializer():
231232
args=args)
232233

233234

235+
def _gad_channel_serializer():
236+
"""Make standard serializer for GeneralizedAmplitudeDamping."""
237+
238+
args = [
239+
# cirq channels can't contain symbols.
240+
cirq.google.SerializingArg(serialized_name="p",
241+
serialized_type=float,
242+
op_getter=lambda x: x.gate.p),
243+
cirq.google.SerializingArg(serialized_name="gamma",
244+
serialized_type=float,
245+
op_getter=lambda x: x.gate.gamma),
246+
cirq.google.SerializingArg(serialized_name="control_qubits",
247+
serialized_type=str,
248+
op_getter=lambda x: ''),
249+
cirq.google.SerializingArg(serialized_name="control_values",
250+
serialized_type=str,
251+
op_getter=lambda x: '')
252+
]
253+
return cirq.google.GateOpSerializer(
254+
gate_type=cirq.GeneralizedAmplitudeDampingChannel,
255+
serialized_gate_id="GAD",
256+
args=args,
257+
can_serialize_predicate=_CONSTANT_TRUE)
258+
259+
260+
def _gad_channel_deserializer():
261+
"""Make standard deserializer for GeneralizedAmplitudeDamping."""
262+
263+
args = [
264+
cirq.google.DeserializingArg(serialized_name="p",
265+
constructor_arg_name="p"),
266+
cirq.google.DeserializingArg(serialized_name="gamma",
267+
constructor_arg_name="gamma")
268+
]
269+
return cirq.google.GateOpDeserializer(
270+
serialized_gate_id="GAD",
271+
gate_constructor=cirq.GeneralizedAmplitudeDampingChannel,
272+
args=args)
273+
274+
275+
def _amplitude_damp_channel_serializer():
276+
"""Make standard serializer for AmplitudeDamp channel."""
277+
278+
args = [
279+
# cirq channels can't contain symbols.
280+
cirq.google.SerializingArg(serialized_name="gamma",
281+
serialized_type=float,
282+
op_getter=lambda x: x.gate.gamma),
283+
cirq.google.SerializingArg(serialized_name="control_qubits",
284+
serialized_type=str,
285+
op_getter=lambda x: ''),
286+
cirq.google.SerializingArg(serialized_name="control_values",
287+
serialized_type=str,
288+
op_getter=lambda x: '')
289+
]
290+
return cirq.google.GateOpSerializer(gate_type=cirq.AmplitudeDampingChannel,
291+
serialized_gate_id="AD",
292+
args=args,
293+
can_serialize_predicate=_CONSTANT_TRUE)
294+
295+
296+
def _amplitude_damp_channel_deserializer():
297+
"""Make standard deserializer for depolarization channel."""
298+
299+
args = [
300+
cirq.google.DeserializingArg(serialized_name="gamma",
301+
constructor_arg_name="gamma")
302+
]
303+
return cirq.google.GateOpDeserializer(
304+
serialized_gate_id="AD",
305+
gate_constructor=cirq.AmplitudeDampingChannel,
306+
args=args)
307+
308+
309+
def _reset_channel_serializer():
310+
"""Make standard serializer for reset channel."""
311+
312+
args = [
313+
# cirq channels can't contain symbols.
314+
cirq.google.SerializingArg(serialized_name="control_qubits",
315+
serialized_type=str,
316+
op_getter=lambda x: ''),
317+
cirq.google.SerializingArg(serialized_name="control_values",
318+
serialized_type=str,
319+
op_getter=lambda x: '')
320+
]
321+
return cirq.google.GateOpSerializer(gate_type=cirq.ResetChannel,
322+
serialized_gate_id="RST",
323+
args=args,
324+
can_serialize_predicate=_CONSTANT_TRUE)
325+
326+
327+
def _reset_channel_deserializer():
328+
"""Make standard deserializer for reset channel."""
329+
330+
args = []
331+
return cirq.google.GateOpDeserializer(serialized_gate_id="RST",
332+
gate_constructor=cirq.ResetChannel,
333+
args=args)
334+
335+
336+
def _phase_damp_channel_serializer():
337+
"""Make standard serializer for PhaseDamp channel."""
338+
args = [
339+
# cirq channels can't contain symbols.
340+
cirq.google.SerializingArg(serialized_name="gamma",
341+
serialized_type=float,
342+
op_getter=lambda x: x.gate.gamma),
343+
cirq.google.SerializingArg(serialized_name="control_qubits",
344+
serialized_type=str,
345+
op_getter=lambda x: ''),
346+
cirq.google.SerializingArg(serialized_name="control_values",
347+
serialized_type=str,
348+
op_getter=lambda x: '')
349+
]
350+
return cirq.google.GateOpSerializer(gate_type=cirq.PhaseDampingChannel,
351+
serialized_gate_id="PD",
352+
args=args,
353+
can_serialize_predicate=_CONSTANT_TRUE)
354+
355+
356+
def _phase_damp_channel_deserializer():
357+
"""Make standard deserializer for PhaseDamp channel."""
358+
args = [
359+
cirq.google.DeserializingArg(serialized_name="gamma",
360+
constructor_arg_name="gamma")
361+
]
362+
return cirq.google.GateOpDeserializer(
363+
serialized_gate_id="PD",
364+
gate_constructor=cirq.PhaseDampingChannel,
365+
args=args)
366+
367+
368+
def _phase_flip_channel_serializer():
369+
"""Make standard serializer for PhaseFlip channel."""
370+
args = [
371+
# cirq channels can't contain symbols.
372+
cirq.google.SerializingArg(serialized_name="p",
373+
serialized_type=float,
374+
op_getter=lambda x: x.gate.p),
375+
cirq.google.SerializingArg(serialized_name="control_qubits",
376+
serialized_type=str,
377+
op_getter=lambda x: ''),
378+
cirq.google.SerializingArg(serialized_name="control_values",
379+
serialized_type=str,
380+
op_getter=lambda x: '')
381+
]
382+
return cirq.google.GateOpSerializer(gate_type=cirq.PhaseFlipChannel,
383+
serialized_gate_id="PF",
384+
args=args,
385+
can_serialize_predicate=_CONSTANT_TRUE)
386+
387+
388+
def _phase_flip_channel_deserializer():
389+
"""Make standard deserializer for PhaseFlip channel."""
390+
391+
args = [
392+
cirq.google.DeserializingArg(serialized_name="p",
393+
constructor_arg_name="p")
394+
]
395+
return cirq.google.GateOpDeserializer(
396+
serialized_gate_id="PF",
397+
gate_constructor=cirq.PhaseFlipChannel,
398+
args=args)
399+
400+
234401
def _bit_flip_channel_serializer():
235402
"""Make standard serializer for BitFlip channel."""
236-
237403
args = [
238404
# cirq channels can't contain symbols.
239405
cirq.google.SerializingArg(serialized_name="p",
@@ -254,7 +420,6 @@ def _bit_flip_channel_serializer():
254420

255421
def _bit_flip_channel_deserializer():
256422
"""Make standard deserializer for BitFlip channel."""
257-
258423
args = [
259424
cirq.google.DeserializingArg(serialized_name="p",
260425
constructor_arg_name="p")
@@ -264,6 +429,7 @@ def _bit_flip_channel_deserializer():
264429
args=args)
265430

266431

432+
# Gates.
267433
def _eigen_gate_serializer(gate_type, serialized_id):
268434
"""Make standard serializer for eigen gates."""
269435

@@ -562,11 +728,16 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
562728
_phased_eigen_gate_serializer(g, g_name)
563729
for g, g_name in PHASED_EIGEN_GATES_DICT.items()
564730
] + [
731+
_amplitude_damp_channel_serializer(),
565732
_asymmetric_depolarize_serializer(),
566733
_bit_flip_channel_serializer(),
567734
_depolarize_channel_serializer(),
568735
_fsim_gate_serializer(),
569-
_identity_gate_serializer()
736+
_gad_channel_serializer(),
737+
_identity_gate_serializer(),
738+
_phase_damp_channel_serializer(),
739+
_reset_channel_serializer(),
740+
_phase_flip_channel_serializer()
570741
]
571742

572743
DESERIALIZERS = [
@@ -576,11 +747,16 @@ def _scalar_combiner(exponent, global_shift, exponent_scalar,
576747
_phased_eigen_gate_deserializer(g, g_name)
577748
for g, g_name in PHASED_EIGEN_GATES_DICT.items()
578749
] + [
750+
_amplitude_damp_channel_deserializer(),
579751
_asymmetric_depolarize_deserializer(),
580752
_bit_flip_channel_deserializer(),
581753
_depolarize_channel_deserializer(),
582754
_fsim_gate_deserializer(),
583-
_identity_gate_deserializer()
755+
_gad_channel_deserializer(),
756+
_identity_gate_deserializer(),
757+
_phase_damp_channel_deserializer(),
758+
_reset_channel_deserializer(),
759+
_phase_flip_channel_deserializer()
584760
]
585761

586762
SERIALIZER = cirq.google.SerializableGateSet(gate_set_name="tfq_gate_set",

tensorflow_quantum/core/serialize/serializer_test.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,28 @@ def _get_noise_proto_pairs():
415415
_build_op_proto("ADP", ['p_x', 'p_y', 'p_z'], [0.1, 0.2, 0.3],
416416
['0_0'])),
417417

418+
# Generalized Amplitude damp.
419+
(cirq.Circuit(cirq.generalized_amplitude_damp(p=0.1, gamma=0.2)(q0)),
420+
_build_op_proto("GAD", ['p', 'gamma'], [0.1, 0.2], ['0_0'])),
421+
422+
# Amplitude damp.
423+
(cirq.Circuit(cirq.amplitude_damp(gamma=0.1)(q0)),
424+
_build_op_proto("AD", ['gamma'], [0.1], ['0_0'])),
425+
426+
# Reset.
427+
(cirq.Circuit(cirq.reset(q0)), _build_op_proto("RST", [], [], ['0_0'])),
428+
429+
# Phase damp.
430+
(cirq.Circuit(cirq.phase_damp(gamma=0.1)(q0)),
431+
_build_op_proto("PD", ['gamma'], [0.1], ['0_0'])),
432+
433+
# Phase flip.
434+
(cirq.Circuit(cirq.phase_flip(p=0.1)(q0)),
435+
_build_op_proto("PF", ['p'], [0.1], ['0_0'])),
436+
418437
# Bit flip.
419438
(cirq.Circuit(cirq.bit_flip(p=0.1)(q0)),
420-
_build_op_proto("BF", ['p'], [0.1], ['0_0'])),
439+
_build_op_proto("BF", ['p'], [0.1], ['0_0']))
421440
]
422441
return pairs
423442

tensorflow_quantum/core/src/circuit_parser_qsim.cc

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,103 @@ inline Status DepolarizingChannel(const Operation& op,
641641
return Status::OK();
642642
}
643643

644+
inline Status GADChannel(const Operation& op, const unsigned int num_qubits,
645+
const unsigned int time, NoisyQsimCircuit* ncircuit) {
646+
int q;
647+
bool unused;
648+
float p, gamma;
649+
Status u;
650+
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);
651+
652+
u = ParseProtoArg(op, "p", {}, &p);
653+
if (!u.ok()) {
654+
return u;
655+
}
656+
u = ParseProtoArg(op, "gamma", {}, &gamma);
657+
if (!u.ok()) {
658+
return u;
659+
}
660+
661+
auto chan = qsim::Cirq::GeneralizedAmplitudeDampingChannel<float>::Create(
662+
time, num_qubits - q - 1, p, gamma);
663+
ncircuit->channels.push_back(chan);
664+
return Status::OK();
665+
}
666+
667+
inline Status ResetChannel(const Operation& op, const unsigned int num_qubits,
668+
const unsigned int time,
669+
NoisyQsimCircuit* ncircuit) {
670+
int q;
671+
bool unused;
672+
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);
673+
674+
auto chan = qsim::Cirq::ResetChannel<float>::Create(time, num_qubits - q - 1);
675+
ncircuit->channels.push_back(chan);
676+
return Status::OK();
677+
}
678+
679+
inline Status AmplitudeDampingChannel(const Operation& op,
680+
const unsigned int num_qubits,
681+
const unsigned int time,
682+
NoisyQsimCircuit* ncircuit) {
683+
int q;
684+
bool unused;
685+
float gamma;
686+
Status u;
687+
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);
688+
689+
u = ParseProtoArg(op, "gamma", {}, &gamma);
690+
if (!u.ok()) {
691+
return u;
692+
}
693+
auto chan = qsim::Cirq::AmplitudeDampingChannel<float>::Create(
694+
time, num_qubits - q - 1, gamma);
695+
ncircuit->channels.push_back(chan);
696+
return Status::OK();
697+
}
698+
699+
inline Status PhaseDampingChannel(const Operation& op,
700+
const unsigned int num_qubits,
701+
const unsigned int time,
702+
NoisyQsimCircuit* ncircuit) {
703+
int q;
704+
bool unused;
705+
float gamma;
706+
Status u;
707+
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);
708+
709+
u = ParseProtoArg(op, "gamma", {}, &gamma);
710+
if (!u.ok()) {
711+
return u;
712+
}
713+
714+
auto chan = qsim::Cirq::PhaseDampingChannel<float>::Create(
715+
time, num_qubits - q - 1, gamma);
716+
ncircuit->channels.push_back(chan);
717+
return Status::OK();
718+
}
719+
720+
inline Status PhaseFlipChannel(const Operation& op,
721+
const unsigned int num_qubits,
722+
const unsigned int time,
723+
NoisyQsimCircuit* ncircuit) {
724+
int q;
725+
bool unused;
726+
float p;
727+
Status u;
728+
unused = absl::SimpleAtoi(op.qubits(0).id(), &q);
729+
730+
u = ParseProtoArg(op, "p", {}, &p);
731+
if (!u.ok()) {
732+
return u;
733+
}
734+
735+
auto chan =
736+
qsim::Cirq::PhaseFlipChannel<float>::Create(time, num_qubits - q - 1, p);
737+
ncircuit->channels.push_back(chan);
738+
return Status::OK();
739+
}
740+
644741
inline Status BitFlipChannel(const Operation& op, const unsigned int num_qubits,
645742
const unsigned int time,
646743
NoisyQsimCircuit* ncircuit) {
@@ -654,6 +751,7 @@ inline Status BitFlipChannel(const Operation& op, const unsigned int num_qubits,
654751
if (!u.ok()) {
655752
return u;
656753
}
754+
657755
auto chan =
658756
qsim::Cirq::BitFlipChannel<float>::Create(time, num_qubits - q - 1, p);
659757
ncircuit->channels.push_back(chan);
@@ -668,9 +766,11 @@ tensorflow::Status ParseAppendChannel(const Operation& op,
668766
static const absl::flat_hash_map<
669767
std::string, std::function<Status(const Operation&, const unsigned int,
670768
const unsigned int, NoisyQsimCircuit*)>>
671-
chan_func_map = {{"DP", &DepolarizingChannel},
672-
{"ADP", &AsymmetricDepolarizingChannel},
673-
{"BF", &BitFlipChannel}};
769+
chan_func_map = {
770+
{"DP", &DepolarizingChannel}, {"ADP", &AsymmetricDepolarizingChannel},
771+
{"GAD", &GADChannel}, {"AD", &AmplitudeDampingChannel},
772+
{"RST", &ResetChannel}, {"PD", &PhaseDampingChannel},
773+
{"PF", &PhaseFlipChannel}, {"BF", &BitFlipChannel}};
674774

675775
auto build_f = chan_func_map.find(op.gate().id());
676776
if (build_f == chan_func_map.end()) {

0 commit comments

Comments
 (0)