Skip to content

Commit 1c49a5b

Browse files
Liu KeyuLiu Keyu
authored andcommitted
Fix bugs
1 parent 2692b96 commit 1c49a5b

File tree

4 files changed

+45
-97
lines changed

4 files changed

+45
-97
lines changed

src/mqt/predictor/rl/actions.py

Lines changed: 25 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
from qiskit.transpiler import CouplingMap
5151
from qiskit.transpiler.passes import (
5252
ApplyLayout,
53-
BasicSwap,
5453
BasisTranslator,
5554
Collect2qBlocks,
5655
CollectCliffords,
@@ -125,6 +124,7 @@ class Action:
125124
pass_type: PassType
126125
transpile_pass: (
127126
list[qiskit_BasePass | tket_BasePass]
127+
| Callable[..., list[Any]]
128128
| Callable[..., list[qiskit_BasePass | tket_BasePass]]
129129
| Callable[
130130
...,
@@ -144,7 +144,8 @@ class DeviceDependentAction(Action):
144144
"""Action that represents a device-specific compilation pass that can be applied to a specific device."""
145145

146146
transpile_pass: (
147-
Callable[..., list[qiskit_BasePass | tket_BasePass]]
147+
Callable[..., list[Any]]
148+
| Callable[..., list[qiskit_BasePass | tket_BasePass]]
148149
| Callable[
149150
...,
150151
Callable[..., tuple[Any, ...] | Circuit],
@@ -466,7 +467,7 @@ def get_openqasm_gates() -> list[str]:
466467

467468
register_action(
468469
DeviceDependentAction(
469-
"SabreLayout+BasicSwap",
470+
"SabreLayout+AIRouting",
470471
CompilationOrigin.QISKIT,
471472
PassType.MAPPING,
472473
stochastic=True,
@@ -482,29 +483,18 @@ def get_openqasm_gates() -> list[str]:
482483
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
483484
EnlargeWithAncilla(),
484485
ApplyLayout(),
485-
BasicSwap(coupling_map=CouplingMap(device.build_coupling_map())),
486+
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="improve"),
486487
],
487488
)
488489
)
489490

490491
register_action(
491492
DeviceDependentAction(
492-
"SabreLayout+AIRouting",
493+
"AIRouting",
493494
CompilationOrigin.QISKIT,
494495
PassType.MAPPING,
495496
stochastic=True,
496-
transpile_pass=lambda device, max_iteration=(20, 20): [
497-
SabreLayout(
498-
coupling_map=CouplingMap(device.build_coupling_map()),
499-
skip_routing=True,
500-
layout_trials=max_iteration[0],
501-
swap_trials=max_iteration[1],
502-
max_iterations=4,
503-
seed=None,
504-
),
505-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
506-
EnlargeWithAncilla(),
507-
ApplyLayout(),
497+
transpile_pass=lambda device: [
508498
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="optimize"),
509499
],
510500
)
@@ -532,20 +522,6 @@ def get_openqasm_gates() -> list[str]:
532522
)
533523
)
534524

535-
register_action(
536-
DeviceDependentAction(
537-
name="DenseLayout+BasicSwap",
538-
origin=CompilationOrigin.QISKIT,
539-
pass_type=PassType.MAPPING,
540-
transpile_pass=lambda device: [
541-
DenseLayout(coupling_map=CouplingMap(device.build_coupling_map())),
542-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
543-
EnlargeWithAncilla(),
544-
ApplyLayout(),
545-
BasicSwap(coupling_map=CouplingMap(device.build_coupling_map())),
546-
],
547-
)
548-
)
549525

550526
register_action(
551527
DeviceDependentAction(
@@ -579,45 +555,11 @@ def get_openqasm_gates() -> list[str]:
579555
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
580556
EnlargeWithAncilla(),
581557
ApplyLayout(),
582-
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="optimize"),
558+
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="improve"),
583559
],
584560
)
585561
)
586562

587-
register_action(
588-
DeviceDependentAction(
589-
name="VF2Layout+BasicSwap",
590-
origin=CompilationOrigin.QISKIT,
591-
pass_type=PassType.MAPPING,
592-
transpile_pass=lambda device: [
593-
VF2Layout(
594-
coupling_map=CouplingMap(device.build_coupling_map()),
595-
target=device,
596-
),
597-
ConditionalController(
598-
[
599-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
600-
EnlargeWithAncilla(),
601-
ApplyLayout(),
602-
],
603-
condition=lambda property_set: property_set["VF2Layout_stop_reason"]
604-
== VF2LayoutStopReason.SOLUTION_FOUND,
605-
),
606-
ConditionalController(
607-
[
608-
TrivialLayout(coupling_map=CouplingMap(device.build_coupling_map())),
609-
FullAncillaAllocation(coupling_map=CouplingMap(device.build_coupling_map())),
610-
EnlargeWithAncilla(),
611-
ApplyLayout(),
612-
],
613-
# Run if VF2Layout did not find a solution
614-
condition=lambda property_set: property_set["VF2Layout_stop_reason"]
615-
!= VF2LayoutStopReason.SOLUTION_FOUND,
616-
),
617-
BasicSwap(coupling_map=CouplingMap(device.build_coupling_map())),
618-
],
619-
)
620-
)
621563

622564
register_action(
623565
DeviceDependentAction(
@@ -691,7 +633,7 @@ def get_openqasm_gates() -> list[str]:
691633
condition=lambda property_set: property_set["VF2Layout_stop_reason"]
692634
!= VF2LayoutStopReason.SOLUTION_FOUND,
693635
),
694-
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="optimize"),
636+
SafeAIRouting(coupling_map=device.build_coupling_map(), optimization_level=3, layout_mode="improve"),
695637
],
696638
)
697639
)
@@ -707,22 +649,22 @@ def get_openqasm_gates() -> list[str]:
707649
)
708650
)
709651

710-
register_action(
711-
DeviceDependentAction(
712-
"BQSKitSynthesis",
713-
CompilationOrigin.BQSKIT,
714-
PassType.SYNTHESIS,
715-
transpile_pass=lambda device: lambda bqskit_circuit: bqskit_compile(
716-
bqskit_circuit,
717-
model=MachineModel(bqskit_circuit.num_qudits, gate_set=get_bqskit_native_gates(device)),
718-
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
719-
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
720-
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" else 3,
721-
seed=10,
722-
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
723-
),
724-
)
725-
)
652+
# register_action(
653+
# DeviceDependentAction(
654+
# "BQSKitSynthesis",
655+
# CompilationOrigin.BQSKIT,
656+
# PassType.SYNTHESIS,
657+
# transpile_pass=lambda device: lambda bqskit_circuit: bqskit_compile(
658+
# bqskit_circuit,
659+
# model=MachineModel(bqskit_circuit.num_qudits, gate_set=get_bqskit_native_gates(device)),
660+
# optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
661+
# synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
662+
# max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" else 3,
663+
# seed=10,
664+
# num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
665+
# ),
666+
# )
667+
# )
726668

727669
register_action(
728670
DeviceIndependentAction(

src/mqt/predictor/rl/helper.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,20 @@ def best_of_n_passmanager(
7878
msg = f"Expected list of passes, got {type(all_passes)}"
7979
raise TypeError(msg)
8080

81-
layout_passes = all_passes[:-1]
82-
routing_pass = all_passes[-1:]
81+
if len(all_passes) == 1:
82+
layouted_qc = qc
83+
layout_props = {}
84+
routing_pass = all_passes
85+
else:
86+
layout_passes = all_passes[:-1]
87+
routing_pass = all_passes[-1:]
8388

84-
# Run layout once
85-
layout_pm = PassManager(layout_passes)
86-
try:
87-
layouted_qc = layout_pm.run(qc)
88-
layout_props = dict(layout_pm.property_set)
89-
except Exception:
90-
return qc, {}
89+
layout_pm = PassManager(layout_passes)
90+
try:
91+
layouted_qc = layout_pm.run(qc)
92+
layout_props = dict(layout_pm.property_set)
93+
except Exception:
94+
return qc, {}
9195

9296
# Run routing multiple times and optimize for the given metric
9397
for i in range(max_iteration[1]):

src/mqt/predictor/rl/predictorenv.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any
188188
RuntimeError: If no valid actions are left.
189189
"""
190190
self.used_actions.append(str(self.action_set[action].name))
191+
logger.info(f"Applying: {self.action_set[action].name!s}")
191192
altered_qc = self.apply_action(action)
192193
if not altered_qc:
193194
return (
@@ -466,18 +467,19 @@ def _apply_bqskit_action(self, action: Action, action_index: int) -> QuantumCirc
466467

467468
def determine_valid_actions_for_state(self) -> list[int]:
468469
"""Determines and returns the valid actions for the current state."""
469-
check_nat_gates = GatesInBasis(target=self.device)
470+
check_nat_gates = GatesInBasis(basis_gates=self.device.operation_names)
470471
check_nat_gates(self.state)
471472
only_nat_gates = check_nat_gates.property_set["all_gates_in_basis"]
472473

473474
check_mapping = CheckMap(coupling_map=CouplingMap(self.device.build_coupling_map()))
474475
check_mapping(self.state)
475476
mapped = check_mapping.property_set["is_swap_mapped"]
476477

478+
# logger.info(f"Native: {only_nat_gates}, Mapped: {mapped}, Layout:{self.layout}")
477479
if not only_nat_gates: # not native gates yet
478480
return self.actions_synthesis_indices + self.actions_opt_indices
479481

480-
if mapped and self.layout is not None: # The circuit is correctly mapped.
482+
if mapped and self.layout is not None: # The circuit is correctly mapped
481483
return [self.action_terminate_index, *self.actions_opt_indices, *self.actions_final_optimization_indices]
482484
# The circuit is not mapped yet
483485
# Or the circuit was mapped but some optimization actions change its structure and the circuit is again unmapped

tests/compilation/test_predictor_rl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_predictor_env_reset_from_string() -> None:
3737
device = get_device("ibm_eagle_127")
3838
predictor = Predictor(figure_of_merit="expected_fidelity", device=device)
3939
qasm_path = Path("test.qasm")
40-
qc = get_benchmark("dj", BenchmarkLevel.ALG, 3)
40+
qc = get_benchmark("dj", BenchmarkLevel.INDEP, 3)
4141
with qasm_path.open("w", encoding="utf-8") as f:
4242
dump(qc, f)
4343
assert predictor.env.reset(qc=qasm_path)[0] == create_feature_dict(qc)
@@ -69,7 +69,7 @@ def test_qcompile_with_newly_trained_models() -> None:
6969
"""
7070
figure_of_merit = "expected_fidelity"
7171
device = get_device("ibm_falcon_127")
72-
qc = get_benchmark("ghz", BenchmarkLevel.ALG, 3)
72+
qc = get_benchmark("ghz", BenchmarkLevel.INDEP, 3)
7373
predictor = Predictor(figure_of_merit=figure_of_merit, device=device)
7474

7575
model_name = "model_" + figure_of_merit + "_" + device.description
@@ -95,7 +95,7 @@ def test_qcompile_with_newly_trained_models() -> None:
9595

9696
def test_qcompile_with_false_input() -> None:
9797
"""Test the qcompile function with false input."""
98-
qc = get_benchmark("dj", BenchmarkLevel.ALG, 5)
98+
qc = get_benchmark("dj", BenchmarkLevel.INDEP, 5)
9999
with pytest.raises(ValueError, match=re.escape("figure_of_merit must not be None if predictor_singleton is None.")):
100100
rl_compile(qc, device=get_device("quantinuum_h2_56"), figure_of_merit=None)
101101
with pytest.raises(ValueError, match=re.escape("device must not be None if predictor_singleton is None.")):

0 commit comments

Comments
 (0)