Skip to content

Commit cb4d0fb

Browse files
committed
Fix FOM comparison logic
1 parent fcab4fa commit cb4d0fb

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/mqt/predictor/rl/predictorenv.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(
137137
for elem in action_dict[PassType.OPT]:
138138
self.action_set[index] = elem
139139
self.actions_opt_indices.append(index)
140-
if getattr(elem, "preserve", False):
140+
if getattr(elem, "preserve_layout", False):
141141
self.actions_mapping_preserving_indices.append(index)
142142
index += 1
143143
for elem in action_dict[PassType.LAYOUT]:
@@ -375,7 +375,13 @@ def fom_aware_compile(
375375
"""
376376
best_result = None
377377
best_property_set = None
378-
best_fom = -1.0
378+
maximize = self.reward_function in [
379+
"expected_fidelity",
380+
"estimated_success_probability",
381+
"estimated_hellinger_distance",
382+
"critical_depth",
383+
]
384+
best_fom = -1.0 if maximize else float("inf")
379385
best_swap_count = float("inf") # for fallback
380386

381387
assert callable(action.transpile_pass), "Mapping action should be callable"
@@ -387,11 +393,7 @@ def fom_aware_compile(
387393

388394
try:
389395
# Synthesize for lookahead fidelity (Mapping could insert non-local SWAP gates)
390-
if self.reward_function in [
391-
"expected_fidelity",
392-
"estimated_success_probability",
393-
"estimated_hellinger_distance",
394-
]:
396+
if maximize:
395397
synth_pass = PassManager([
396398
BasisTranslator(StandardEquivalenceLibrary, target_basis=device.operation_names)
397399
])
@@ -412,7 +414,7 @@ def fom_aware_compile(
412414
except Exception as e:
413415
logger.warning(f"[Fallback to SWAP counts] Synthesis or fidelity computation failed: {e}")
414416
swap_count = out_circ.count_ops().get("swap", 0)
415-
if best_result is None or (best_fom == -1.0 and swap_count < best_swap_count):
417+
if best_result is None or swap_count < best_swap_count:
416418
best_swap_count = swap_count
417419
best_result = out_circ
418420
best_property_set = prop_set
@@ -424,7 +426,7 @@ def fom_aware_compile(
424426
if best_result is not None:
425427
return best_result, best_property_set
426428
logger.error("All attempts failed.")
427-
return qc, {}
429+
return qc, None
428430

429431
def _apply_qiskit_action(self, action: Action, action_index: int) -> QuantumCircuit:
430432
pm_property_set: PropertySet | None = {}

0 commit comments

Comments
 (0)