
Conversation

@antotu antotu commented Aug 19, 2025

Description

This PR introduces a Graph Neural Network (GNN) as an alternative to the Random Forest model for predicting the best device to run a quantum circuit.
To support this, the preprocessing pipeline was redesigned: instead of manually extracting features from the circuit, the model now directly takes as input the Directed Acyclic Graph (DAG) representation of the quantum circuit.


🚀 Major Changes

Graph Neural Network Integration

  • Added a GNN model for predicting the target quantum device and estimating the Hellinger distance between output distributions.
  • Added a preprocessing method to transform quantum circuits into DAGs.
  • DAG representation captures gate dependencies and circuit topology for improved graph-based learning (see the preprocessing sketch after this list).
  • Integrated automated hyperparameter search with Optuna for tuning GNN performance.
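
For illustration, here is a minimal sketch of the DAG preprocessing idea, assuming Qiskit and PyTorch Geometric. The helper name circuit_to_graph_data and the toy two-value node encoding are placeholders; the actual implementation is create_dag in src/mqt/predictor/ml/helper.py.

import torch
from qiskit import QuantumCircuit
from qiskit.converters import circuit_to_dag
from torch_geometric.data import Data


def circuit_to_graph_data(qc: QuantumCircuit) -> Data:
    """Sketch only: encode each DAG op node as a numeric vector and collect dependency edges."""
    dag = circuit_to_dag(qc)
    nodes = list(dag.topological_op_nodes())
    index = {node: i for i, node in enumerate(nodes)}

    # Toy node features: [number of qubits acted on, number of gate parameters]
    x = torch.tensor([[len(n.qargs), len(n.op.params)] for n in nodes], dtype=torch.float)

    # Edges follow gate dependencies in the DAG (wire in/out nodes are skipped)
    edges = [[index[src], index[dst]] for src, dst, _ in dag.edges() if src in index and dst in index]
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return Data(x=x, edge_index=edge_index, num_nodes=len(nodes))

The real create_dag uses a richer per-node encoding (its input dimension is derived from get_openqasm3_gates() plus additional qubit/parameter fields, as visible in the GNN construction quoted in the reviews further down).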

🎯 Motivation

  • Previously, features were manually extracted from the quantum circuit, leading to loss of structural information.
  • This new method preserves the full circuit structure by representing it as a graph.
  • GNNs can exploit graph connectivity to make more accurate predictions.
  • Optuna ensures that GNN hyperparameters are efficiently optimized in a reproducible way.
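
As a rough illustration of the Optuna integration (a sketch only; the real search space and objective live in predictor.py), a TPE-based study could look like this:

import optuna
from optuna.samplers import TPESampler


def objective(trial: optuna.Trial) -> float:
    # Hypothetical GNN hyperparameters mirroring the names that appear in the review below
    hidden_dim = trial.suggest_int("hidden_dim", 32, 256)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)
    mlp = trial.suggest_categorical("mlp", ["none", "64", "128,64"])
    # ... build and cross-validate the GNN with these settings ...
    penalty = 0.0 if mlp == "none" else 0.01 * len(mlp.split(","))
    return (256 - hidden_dim) / 256 + dropout + penalty  # stand-in for the real validation loss


study = optuna.create_study(direction="minimize", sampler=TPESampler(n_startup_trials=10))
study.optimize(objective, n_trials=20)
print(study.best_trial.params)  # e.g. {"hidden_dim": ..., "dropout": ..., "mlp": "128,64"}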

🔧 Fixes and Enhancements

  • Transform input quantum circuits into DAGs, where each node is encoded as a numeric vector.
  • Integrated GNNs as an additional predictor in the pipeline (a minimal model sketch follows below).
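
A minimal sketch of the graph-classification pattern used for device prediction, assuming PyTorch Geometric; the actual GNN class in src/mqt/predictor/ml/gnn.py is more elaborate (residual layers, optional SAG pooling, and a configurable MLP head):

import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_add_pool


class TinyDeviceGNN(torch.nn.Module):
    """Hypothetical minimal variant: two graph convolutions, sum pooling, linear head."""

    def __init__(self, in_feats: int, hidden_dim: int, num_devices: int) -> None:
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.head = torch.nn.Linear(hidden_dim, num_devices)

    def forward(self, data: Data) -> torch.Tensor:
        x = self.conv1(data.x, data.edge_index).relu()
        x = self.conv2(x, data.edge_index).relu()
        # Pool node embeddings into one vector per graph: shape [num_graphs, hidden_dim]
        x = global_add_pool(x, data.batch)
        return self.head(x)  # raw logits, one score per device

Note that for a single, unbatched Data object the output has shape [1, num_devices]; this is exactly the shape issue flagged in the review comments further down.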

📦 Dependency Updates

  • optuna>=4.5.0
  • torch-geometric>=2.6.1

Checklist:

  • The pull request only contains commits that are focused and relevant to this change.
  • I have added appropriate tests that cover the new/changed functionality.
  • I have updated the documentation to reflect these changes.
  • I have added entries to the changelog for any noteworthy additions, changes, fixes, or removals.
  • I have added migration instructions to the upgrade guide (if needed).
  • The changes follow the project's style guidelines and introduce no new warnings.
  • The changes are fully tested and pass the CI checks.
  • I have reviewed my own code changes.

@antotu antotu marked this pull request as draft August 19, 2025 16:16

TYPE_CHECKING = False
if TYPE_CHECKING:
    VERSION_TUPLE = tuple[int | str, ...]

Check warning — Code scanning / CodeQL: Unreachable code (Warning) — This statement is unreachable.

@antotu antotu changed the title from "Gnn branch" to "Add GNN-Based Predictor with DAG Preprocessing" on Aug 21, 2025

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (4)
src/mqt/predictor/ml/predictor.py (4)

503-506: num_outputs parameter is unused; k_folds needs guard for tiny datasets.

Two issues persist from prior review:

  1. Line 503: The num_outputs parameter (line 442) is immediately overwritten by max(1, len(self.devices)), making the parameter useless. Either honor the passed value or remove it from the signature.

  2. Line 506: KFold(n_splits=k_folds) will raise ValueError if k_folds < 2. Since train_gnn_model sets k_folds = min(len(training_data.y_train), 5), a dataset with 1 sample causes failure.

Apply this diff:

-        num_outputs = max(1, len(self.devices))
+        # num_outputs already passed from caller; use it directly
 
         # Split into k-folds
+        k_folds = max(2, k_folds)  # KFold requires at least 2 splits
         kf = KFold(n_splits=k_folds, shuffle=True)

And either remove num_outputs from the signature or use num_outputs = num_outputs or max(1, len(self.devices)).


604-610: Dead code: device descriptions computed but not stored.

Line 607 computes [dev.description for dev in self.devices] but the result is discarded. This list is needed for inference to map model outputs back to device names.

Store the device labels in the JSON metadata:

             if num_outputs == 1:
                 # task = "binary"
                 task = "classification"
-            [dev.description for dev in self.devices]
+            else:
+                task = "classification"
+        device_classes = [dev.description for dev in self.devices]
         sampler_obj = TPESampler(n_startup_trials=10)

Then add to json_dict before saving:

         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["classes"] = device_classes if self.figure_of_merit != "hellinger_distance" else None

696-701: Verbose evaluation ignores task; always uses classification metrics for regression.

When self.figure_of_merit is "estimated_hellinger_distance" or "hellinger_distance", task = "regression" (line 602), but the verbose block unconditionally calls evaluate_classification_model. This produces meaningless metrics for regression tasks.

Branch on the task variable:

         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

856-869: Critical: gnn_model.classes attribute does not exist; prediction will fail with AttributeError.

Line 860 accesses gnn_model.classes, but the GNN class (from gnn.py) never defines a classes attribute. This will cause an AttributeError at runtime.

Additionally:

  • Line 856: torch.load(path) should specify map_location for CPU/GPU portability and weights_only=True for security (state dicts are safe).
  • Line 847: Will raise KeyError because "mlp_units" is not stored in JSON (only "mlp" as a string).

The class labels should be loaded from the JSON metadata (once stored during training):

-        gnn_model.load_state_dict(torch.load(path))
+        gnn_model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
         x, edge_index, number_of_gates = create_dag(qc)
         feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
         gnn_model.eval()
-        class_labels = gnn_model.classes
+        class_labels = json_dict.get("classes")  # Load from JSON instead of model attribute
         with torch.no_grad():
             outputs = gnn_model(feature_vector)

And fix the mlp_units loading:

-            mlp_units=json_dict["mlp_units"],
+            mlp_units=([] if json_dict["mlp"] == "none" else [int(x) for x in json_dict["mlp"].split(",")]),
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2e28a14 and e9f698e.

📒 Files selected for processing (1)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/mqt/predictor/ml/predictor.py (3)
src/mqt/predictor/ml/gnn.py (1)
  • GNN (191-295)
src/mqt/predictor/hellinger/utils.py (1)
  • get_hellinger_model_path (135-144)
src/mqt/predictor/ml/helper.py (4)
  • TrainingData (569-579)
  • create_dag (192-301)
  • create_feature_vector (164-189)
  • get_openqasm3_gates (119-151)
🪛 Ruff (0.14.5)
src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


571-571: Boolean-typed positional argument in function definition

(FBT001)


571-571: Boolean default positional argument in function definition

(FBT002)


802-802: Boolean-typed positional argument in function definition

(FBT001)


802-802: Boolean default positional argument in function definition

(FBT002)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (5)
src/mqt/predictor/ml/predictor.py (5)

83-98: Type definitions are well-structured.

The type aliases (GNNSample, FeatureSample, TrainingSample) and TrainGNNKwargs TypedDict provide good type safety for the GNN pathway.


100-163: Clean conditional training logic.

The implementation correctly branches between RF and GNN training based on the gnn flag, and the **gnn_kwargs pattern cleanly forwards GNN-specific parameters. The boolean parameter is acceptable here for API simplicity.


326-341: Data construction for GNN samples is reasonable, but note the y-tensor semantics.

The code discards the _y from training_sample (which was the target device index) and replaces it with value_device (all device scores). This makes sense for multi-output regression but differs from the classification label. Ensure downstream training/evaluation code expects this multi-output format.


421-434: Return type and GNN sample creation are correct.

The GNN branch properly creates the tuple with node features, target index, edge index, gate count, and target label. The scores_list variable is now assigned the dict directly (despite its name), and it is used correctly in generate_training_data.


769-771: weights_only=False is required but be aware of security implications.

Using torch.load(..., weights_only=False) is necessary to deserialize torch_geometric.data.Data objects, but it allows arbitrary code execution if the .pt file is malicious. Ensure training data files are from trusted sources only.

Comment on lines +627 to +633
dict_best_hyper = study.best_trial.params # user_attrs.get("best_hparams")
# Build model (ensure final layer outputs raw logits/no activation)
json_dict = study.best_trial.params
mlp_str = dict_best_hyper["mlp"]
mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]

json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1

⚠️ Potential issue | 🔴 Critical

mlp_units not stored in JSON metadata; prediction will fail.

Line 847 in predict_device_for_figure_of_merit reads json_dict["mlp_units"], but json_dict is populated from study.best_trial.params which stores "mlp" as a string (e.g., "128,64"), not "mlp_units" as a list.

Add the parsed mlp_units to the JSON before saving:

         mlp_str = dict_best_hyper["mlp"]
         mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]
 
         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
dict_best_hyper = study.best_trial.params # user_attrs.get("best_hparams")
# Build model (ensure final layer outputs raw logits/no activation)
json_dict = study.best_trial.params
mlp_str = dict_best_hyper["mlp"]
mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]
json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
dict_best_hyper = study.best_trial.params # user_attrs.get("best_hparams")
# Build model (ensure final layer outputs raw logits/no activation)
json_dict = study.best_trial.params
mlp_str = dict_best_hyper["mlp"]
mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]
json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
json_dict["mlp_units"] = mlp_units # Store parsed list for inference
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 627 to 633, json_dict is
populated from study.best_trial.params which contains "mlp" as a comma string,
but later code expects json_dict["mlp_units"]; add the parsed mlp_units list
into json_dict before any saving/serialization. Specifically, after computing
mlp_units (from mlp_str), set json_dict["mlp_units"] = mlp_units (or [] when
"none") and ensure num_outputs remains set as before so downstream code can read
mlp_units as a list.

Comment on lines +862 to +869
outputs = gnn_model(feature_vector)
assert class_labels is not None
if len(class_labels) != len(outputs):
msg = "outputs and class_labels must be same length"
raise ValueError(msg)

pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
sorted_devices = np.array([label for _, label in pairs])

⚠️ Potential issue | 🟠 Major

GNN model receives unbatched single Data object; output shape mismatches class_labels length.

Line 862 passes a single Data object to gnn_model(feature_vector). The model's global_add_pool layer returns a tensor of shape [1, num_devices], but lines 864-865 compare len(outputs) (which is 1, the batch dimension) against len(class_labels) (which is the number of device labels, typically > 1). This check will raise. The subsequent zip on line 867 expects to iterate over num_devices predictions, not over a single batch-dimension element.

Additionally, the Data object is constructed a few lines earlier with num_gates=number_of_gates, but PyTorch Geometric's Data class expects num_nodes as the standard attribute name. While custom attributes are permitted, using the standard name is more reliable.

Fix by removing the batch dimension from the model output:

outputs = gnn_model(feature_vector)
outputs = outputs.squeeze(0)  # Remove batch dimension [1, num_devices] -> [num_devices]
assert class_labels is not None
if len(class_labels) != len(outputs):
    msg = "outputs and class_labels must be same length"
    raise ValueError(msg)
pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)

Also update the Data construction to use the standard PyTorch Geometric attribute:

feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 862-869, the GNN is called
with a single Data object so model returns shape [1, num_devices] and the code
incorrectly compares len(outputs) (==1) to len(class_labels) (==num_devices) and
zips them, causing a mismatch; fix by removing the batch dimension from outputs
(e.g., outputs = outputs.squeeze(0) or outputs = outputs[0]) before length
checks and zipping, and also replace the custom Data attribute num_gates with
the standard PyTorch Geometric attribute num_nodes when constructing the Data
object.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

♻️ Duplicate comments (2)
src/mqt/predictor/ml/predictor.py (2)

601-701: Critical: Multiple issues will cause runtime failures in GNN prediction.

Several previously flagged issues remain unresolved:

  1. Line 607: Dead code — list comprehension result is discarded:

    [dev.description for dev in self.devices]  # Result not assigned
  2. Class labels not persisted: json_dict stores only hyperparameters, but predict_device_for_figure_of_merit (line 860) expects gnn_model.classes which is never set. You must save the device labels:

     json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
    +json_dict["classes"] = [dev.description for dev in self.devices]
  3. Line 698: Verbose evaluation always calls evaluate_classification_model, even when task="regression" (Hellinger distance). This produces meaningless metrics:

     if verbose:
         test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
    -    avg_loss_test, dict_results, _ = evaluate_classification_model(
    -        model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
    -    )
    +    if task == "regression":
    +        avg_loss_test, dict_results, _ = evaluate_regression_model(
    +            model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
    +        )
    +    else:
    +        avg_loss_test, dict_results, _ = evaluate_classification_model(
    +            model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
    +        )
         print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

503-506: Critical: num_outputs parameter is unused and k_folds may be invalid.

Two issues persist from previous reviews:

  1. Line 503: The num_outputs parameter is immediately overwritten, making it a dead parameter. Either remove it from the signature or honor the passed value.

  2. Line 506: KFold(n_splits=k_folds) will raise ValueError if k_folds < 2. When len(training_data.y_train) == 1, k_folds becomes 1.

-        num_outputs = max(1, len(self.devices))
+        # Use provided num_outputs, or fall back to device count
+        num_outputs = num_outputs if num_outputs > 0 else max(1, len(self.devices))

         # Split into k-folds
+        k_folds = max(2, k_folds)  # KFold requires n_splits >= 2
         kf = KFold(n_splits=k_folds, shuffle=True)

Alternatively, skip hyperparameter search entirely for datasets with fewer than 2 samples.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2e28a14 and e9f698e.

📒 Files selected for processing (1)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/mqt/predictor/ml/predictor.py (4)
src/mqt/predictor/ml/gnn.py (1)
  • GNN (191-295)
src/mqt/predictor/hellinger/utils.py (1)
  • get_hellinger_model_path (135-144)
src/mqt/predictor/ml/helper.py (10)
  • TrainingData (569-579)
  • create_dag (192-301)
  • create_feature_vector (164-189)
  • evaluate_classification_model (323-390)
  • evaluate_regression_model (393-459)
  • get_openqasm3_gates (119-151)
  • get_path_trained_model (48-50)
  • get_path_trained_model_gnn (53-57)
  • get_path_training_data (36-38)
  • train_model (462-565)
src/mqt/predictor/rl/predictor.py (1)
  • train_model (87-127)
🪛 Ruff (0.14.5)
src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


571-571: Boolean-typed positional argument in function definition

(FBT001)


571-571: Boolean default positional argument in function definition

(FBT002)


802-802: Boolean-typed positional argument in function definition

(FBT001)


802-802: Boolean default positional argument in function definition

(FBT002)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (3)
src/mqt/predictor/ml/predictor.py (3)

83-98: Type definitions look good.

The GNNSample, FeatureSample, TrainingSample type aliases and TrainGNNKwargs TypedDict provide clear typing for the GNN data flow.


322-353: Training data generation logic is correct.

The GNN branch properly creates Data objects with multi-device scores as targets and preserves target_label for later use. The conditional serialization to .pt vs .npy is appropriate.


769-771: Security note: weights_only=False allows arbitrary code execution.

Using torch.load(..., weights_only=False) is necessary to deserialize torch_geometric.data.Data objects, but it can execute arbitrary Python code if the .pt file is malicious. This is acceptable for training data generated by the pipeline itself, but ensure this file is never loaded from untrusted sources.

Consider adding a comment documenting this security consideration.
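
A minimal sketch of such a comment (the variable and path names are placeholders):

# SECURITY: weights_only=False is required to unpickle torch_geometric.data.Data objects,
# but it allows arbitrary code execution. Only load .pt files produced by this pipeline,
# never files from untrusted sources.
training_data = torch.load(training_data_path, weights_only=False)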

Comment on lines +107 to +108
gnn: bool = False,
**gnn_kwargs: Unpack[TrainGNNKwargs],

🧹 Nitpick | 🔵 Trivial

Consider keyword-only for gnn parameter.

While functional, making gnn keyword-only (after *) would prevent accidental positional usage and address the Ruff FBT001/FBT002 hints consistently across the codebase.

 def setup_device_predictor(
     devices: list[Target],
     figure_of_merit: figure_of_merit = "expected_fidelity",
     path_uncompiled_circuits: Path | None = None,
     path_compiled_circuits: Path | None = None,
     path_training_data: Path | None = None,
     timeout: int = 600,
+    *,
     gnn: bool = False,
     **gnn_kwargs: Unpack[TrainGNNKwargs],
 ) -> bool:
🧰 Tools
🪛 Ruff (0.14.5)

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)

🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 107-108, make the gnn
parameter keyword-only to avoid accidental positional use and satisfy Ruff
FBT001/FBT002: change the function signature so gnn appears after a bare *
(e.g., add a * before gnn) so callers must pass gnn by name, keep its type and
default (gnn: bool = False) and keep **gnn_kwargs unchanged; then update any
local call sites that passed gnn positionally to use the keyword form and run
linters/tests to ensure no breakage.

Comment on lines +421 to +432
scores_list = scores  # list(scores.values())
target_label = max(scores, key=lambda k: scores[k])

qc = QuantumCircuit.from_qasm_file(path_uncompiled_circuit / file)
feature_vec = create_feature_vector(qc)
training_sample = (feature_vec, target_label)
training_sample: TrainingSample
if self.gnn:
    x, edge_index, number_of_gates = create_dag(qc)
    y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
    training_sample = (x, y, edge_index, number_of_gates, target_label)
else:
    feature_vec = create_feature_vector(qc)
    training_sample = (feature_vec, target_label)

🧹 Nitpick | 🔵 Trivial

Clarify or remove the intermediate y tensor that gets discarded.

The y tensor created at line 428 (device index) is immediately discarded in generate_training_data (line 328: x, _y, edge_idx, ...) and replaced with the full device scores tensor. This intermediate value adds confusion without serving a purpose.

Consider either:

  1. Not computing y here and returning a 4-tuple (x, edge_index, number_of_gates, target_label), or
  2. Adding a comment explaining why this value exists but is unused.
         if self.gnn:
             x, edge_index, number_of_gates = create_dag(qc)
-            y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
-            training_sample = (x, y, edge_index, number_of_gates, target_label)
+            # Note: y placeholder - actual targets are computed in generate_training_data
+            training_sample = (x, None, edge_index, number_of_gates, target_label)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 421 to 432, the code
constructs a temporary tensor y (device index) that is never used downstream and
is discarded by generate_training_data; remove the creation of y and change the
training_sample for the GNN branch to a 4-tuple (x, edge_index, number_of_gates,
target_label) to match the rest of the pipeline, or if you must keep it for
clarity/compatibility add a one-line comment explaining it is unused and only
created for legacy reasons; ensure any call sites or tuple unpacking
expectations are updated accordingly to avoid breaking consumers.

Comment on lines +674 to +688
x_train, x_val, _y_train, _y_val = train_test_split(
    training_data.X_train, training_data.y_train, test_size=0.2, random_state=5
)
# Dataloader
train_loader = DataLoader(x_train, batch_size=16, shuffle=True)

val_loader = DataLoader(x_val, batch_size=16, shuffle=False)
train_model(
    model,
    train_loader,
    optimizer,
    loss_fn,
    task=task,
    num_epochs=num_epochs,
    device=device,

⚠️ Potential issue | 🟡 Minor

Minor: Double train/val split and incorrect device type.

  1. Lines 674-676: Data is split again (80/20) after _get_prepared_training_data already performed a 70/30 split. This results in effective train/val/test proportions of 56%/14%/30%. If intentional, consider documenting; otherwise, consider using training_data.X_test as validation set.

  2. Line 688: train_model expects device: str | None, but a torch.device object is passed. This works due to duck typing but is technically a type mismatch. Consider passing str(device) or "cuda" if torch.cuda.is_available() else "cpu".

🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 674 to 688, remove the
redundant train/val split (currently splitting training_data.X_train again) and
instead use the prepared validation set (e.g., training_data.X_test) or
explicitly document why a second split is required; update the DataLoader
instantiation to use the chosen train and val arrays accordingly. Also change
the device argument passed to train_model from a torch.device object to a string
(e.g., str(device) or a conditional "cuda"/"cpu" value) so the parameter matches
the expected type.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/device_selection/test_predictor_ml.py (1)

75-78: Missing gnn parameter in prediction call will cause test failure for GNN cases.

When the test is parameterized with gnn=True, the training creates a GNN model (.pth file), but predict_device_for_figure_of_merit is called without gnn=True. This causes the prediction to look for a non-existent Random Forest model (.joblib file), resulting in FileNotFoundError.

     test_qc = get_benchmark("ghz", BenchmarkLevel.ALG, 3)
-    predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity")
+    predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity", gnn=gnn)
♻️ Duplicate comments (10)
src/mqt/predictor/ml/predictor.py (8)

673-679: Double train/validation split reduces effective training data.

_get_prepared_training_data already performs a 70/30 train/test split. Lines 673-675 split again with 80/20, resulting in effective proportions of 56%/14%/30%. Consider using training_data.X_test as validation set if this is unintentional.


628-632: Critical: mlp_units not stored in JSON; prediction will fail with KeyError.

Line 845 reads json_dict["mlp_units"], but this key is never added to json_dict. The code computes mlp_units from mlp_str but doesn't store it. Additionally, class labels should be stored for GNN prediction.

         json_dict = study.best_trial.params
         mlp_str = dict_best_hyper["mlp"]
         mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]

         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["mlp_units"] = mlp_units
+        json_dict["classes"] = [dev.description for dev in self.devices]

633-662: Reduce code duplication in GNN instantiation.

The two GNN construction blocks are nearly identical, differing only in output_dim. Consolidate using a computed output_dim.

-        if self.figure_of_merit != "hellinger_distance":
-            model = GNN(
-                in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
-                # ... many identical lines ...
-                output_dim=len(self.devices),
-                # ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            model = GNN(
-                # ... identical lines ...
-                output_dim=1,
-                # ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
+        output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+        model = GNN(
+            in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+            num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+            hidden_dim=dict_best_hyper["hidden_dim"],
+            num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+            mlp_units=mlp_units,
+            output_dim=output_dim,
+            dropout_p=dict_best_hyper["dropout"],
+            bidirectional=dict_best_hyper["bidirectional"],
+            use_sag_pool=dict_best_hyper["sag_pool"],
+            sag_ratio=0.7,
+            conv_activation=torch.nn.functional.leaky_relu,
+            mlp_activation=torch.nn.functional.leaky_relu,
+        ).to("cuda" if torch.cuda.is_available() else "cpu")

694-699: Verbose test block ignores task variable and always uses classification evaluator.

When task="regression" (for Hellinger distance), calling evaluate_classification_model produces incorrect metrics. Branch on task to call the appropriate evaluator.

         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

858-867: Critical: gnn_model.classes doesn't exist and output shape mismatch will cause failures.

Multiple issues will cause runtime failures:

  1. Line 858: gnn_model.classes is never defined on the GNN class and is not set during training. This raises AttributeError.

  2. Lines 862-864: For a single Data object, gnn_model() returns a tensor of shape [1, num_devices]. len(outputs) returns 1 (the batch dimension), while len(class_labels) is num_devices, so the length check always fails unless there is exactly one device.

  3. Line 866: Even if the assertion passed, outputs.tolist() on [1, num_devices] produces [[score1, score2, ...]], so zip would only iterate once.

Fix by loading class labels from JSON (after storing them during training) and squeezing the output:

-        class_labels = gnn_model.classes
+        class_labels = json_dict["classes"]  # Must be saved during training
         with torch.no_grad():
             outputs = gnn_model(feature_vector)
+        outputs = outputs.squeeze(0)  # Remove batch dimension: [1, N] -> [N]
         assert class_labels is not None
         if len(class_labels) != len(outputs):
             msg = "outputs and class_labels must be same length"
             raise ValueError(msg)

426-429: Intermediate y tensor computed but immediately discarded.

The y tensor at line 428 is created but discarded in generate_training_data (line 328 unpacks with _y). The actual targets are computed from scores in generate_training_data. Consider removing this redundant computation or documenting why it exists.

         if self.gnn:
             x, edge_index, number_of_gates = create_dag(qc)
-            y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
-            training_sample = (x, y, edge_index, number_of_gates, target_label)
+            # y placeholder - actual targets are computed in generate_training_data from scores
+            training_sample = (x, None, edge_index, number_of_gates, target_label)

505-506: Guard against KFold with invalid n_splits for tiny datasets.

If k_folds (derived from min(len(y_train), 5)) is less than 2, KFold(n_splits=k_folds) raises ValueError. For very small datasets, either skip cross-validation or clamp k_folds.

+        k_folds = max(2, k_folds)  # KFold requires at least 2 splits
         kf = KFold(n_splits=k_folds, shuffle=True)

606-606: Dead expression: result of list comprehension is discarded.

This line creates a list of device descriptions but never assigns it. If it was intended to store class labels for prediction, assign it to a variable and include it in json_dict.

-            [dev.description for dev in self.devices]
+            # If needed for later use, assign the result:
+            # device_labels = [dev.description for dev in self.devices]

Or remove the line entirely if not needed.

src/mqt/predictor/ml/helper.py (2)

304-317: get_results_classes silently corrupts target labels for [N,1] shaped tensors.

After targets.unsqueeze(1) at line 361, targets become shape [N, 1]. When argmax(targets, dim=1) is applied to [N, 1] tensors, it always returns 0 (since there's only one element along dim 1), regardless of the actual target values. This corrupts accuracy calculations.

Apply this fix to handle both one-hot and single-column targets:

 def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """Return predicted and target class indices.

     Arguments:
         preds: model predictions
-        targets: ground truth targets
+        targets: ground truth targets (one-hot encoded or integer class labels)
     Returns:
         pred_idx: predicted class indices
         targets_idx: target class indices
     """
     pred_idx = torch.argmax(preds, dim=1)
-    targets_idx = torch.argmax(targets, dim=1)
+    if targets.dim() == 1:
+        targets_idx = targets.long()
+    elif targets.size(1) == 1:
+        targets_idx = targets.view(-1).long()
+    else:
+        targets_idx = torch.argmax(targets, dim=1)

     return pred_idx, targets_idx

331-331: Type annotation mismatch: metrics dict stores str but annotated as dict[str, float].

Line 389 assigns classification_report_res (a string) to metrics["classification_report"], but the return type annotation at line 331 specifies dict[str, float].

-) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
+) -> tuple[float, dict[str, float | str], tuple[np.ndarray, np.ndarray] | None]:

And update the initialization:

-    metrics = {"loss": float(avg_loss)}
+    metrics: dict[str, float | str] = {"loss": float(avg_loss)}

Also applies to: 388-389

📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9f698e and 8b79177.

📒 Files selected for processing (3)
  • src/mqt/predictor/ml/helper.py (4 hunks)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
  • tests/device_selection/test_predictor_ml.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/device_selection/test_predictor_ml.py (3)
src/mqt/predictor/ml/helper.py (1)
  • get_path_training_data (36-38)
src/mqt/predictor/rl/helper.py (1)
  • get_path_training_data (97-99)
src/mqt/predictor/ml/predictor.py (1)
  • Predictor (166-796)
src/mqt/predictor/ml/helper.py (1)
src/mqt/predictor/utils.py (1)
  • calc_supermarq_features (94-146)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)


113-113: Boolean-typed positional argument in function definition

(FBT001)

src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


570-570: Boolean-typed positional argument in function definition

(FBT001)


570-570: Boolean default positional argument in function definition

(FBT002)


800-800: Boolean-typed positional argument in function definition

(FBT001)


800-800: Boolean default positional argument in function definition

(FBT002)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (13)
tests/device_selection/test_predictor_ml.py (4)

38-43: LGTM!

The parameterized test setup correctly covers both Random Forest and GNN paths with appropriate test IDs.


68-73: LGTM!

The conditional assertions correctly verify the expected training artifacts for each model type.


95-99: LGTM!

The cleanup logic correctly handles both .npy and .pt file extensions.


112-117: LGTM!

Appropriate parameterization for testing both GNN and non-GNN paths in _get_prepared_training_data.

src/mqt/predictor/ml/helper.py (6)

53-57: LGTM!

The helper correctly returns the path to the GNN checkpoint file using f-string formatting.


119-151: LGTM!

The OpenQASM 3.0 gate list includes the spec URL reference and version note for future maintenance.


192-301: LGTM!

The DAG construction correctly handles edge cases with empty edges and empty topo_nodes, returning appropriate tensors.


405-469: LGTM!

The regression evaluator correctly handles empty loaders and provides appropriate metrics.


472-571: LGTM!

The training loop correctly handles early stopping, validation, and best model restoration.


574-585: LGTM!

The TrainingData dataclass appropriately supports both classical (numpy) and GNN (graph) model data types.

src/mqt/predictor/ml/predictor.py (3)

92-98: LGTM!

The TrainGNNKwargs TypedDict provides clear type hints for GNN training arguments.


100-163: LGTM!

The setup_device_predictor function correctly routes to GNN or Random Forest training based on the gnn flag.


327-341: LGTM!

The GNN training sample construction correctly creates Data objects with appropriate attributes.

).to("cuda" if torch.cuda.is_available() else "cpu")
gnn_model.load_state_dict(torch.load(path))
x, edge_index, number_of_gates = create_dag(qc)
feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)

⚠️ Potential issue | 🟡 Minor

Use standard PyTorch Geometric attribute num_nodes instead of custom num_gates.

PyTorch Geometric's Data class expects num_nodes as the standard attribute. Using num_gates may cause issues with some GNN operations that rely on the standard attribute.

-        feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
+        feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around line 856, the Data object is created
with a nonstandard attribute num_gates; replace it with the PyTorch Geometric
standard num_nodes by setting num_nodes=number_of_gates (and if backward
compatibility is needed you can also keep num_gates=number_of_gates), so change
the Data(...) call to use num_nodes instead of num_gates.
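As a quick, self-contained illustration of the standard attribute (node features and edges below are dummy values, not the project's featurization):

import torch
from torch_geometric.data import Data

# Build a toy graph sample using the standard num_nodes attribute.
x = torch.zeros((3, 8))                                        # 3 nodes, 8 features each
edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)  # edges 0->1 and 1->2
sample = Data(x=x, edge_index=edge_index, num_nodes=3)
print(sample.num_nodes)                                        # 3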

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (8)
src/mqt/predictor/ml/helper.py (2)

304-317: get_results_classes silently corrupts targets when passed integer class labels.

This function always applies argmax(dim=1) on targets, which only works correctly for one-hot encoded targets. When targets are 1D integer labels or shape [N, 1], argmax returns all zeros, corrupting accuracy metrics. Per evaluate_classification_model at lines 360-361, targets are unsqueezed to [N, 1] when 1D, which then triggers exactly this corruption here.

Apply this diff to handle both encodings:

 def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """Return predicted and target class indices.

     Arguments:
         preds: model predictions
-        targets: ground truth targets
+        targets: ground truth targets (one-hot encoded or integer class labels)
     Returns:
         pred_idx: predicted class indices
         targets_idx: target class indices
     """
     pred_idx = torch.argmax(preds, dim=1)
-    targets_idx = torch.argmax(targets, dim=1)
+    if targets.dim() == 1:
+        targets_idx = targets.long()
+    elif targets.size(1) == 1:
+        targets_idx = targets.view(-1).long()
+    else:
+        targets_idx = torch.argmax(targets, dim=1)

     return pred_idx, targets_idx
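A minimal demonstration of the corruption described above (tensor values are made up):

import torch

targets = torch.tensor([[2], [0], [1]])   # integer class labels as a [N, 1] column
print(torch.argmax(targets, dim=1))       # tensor([0, 0, 0]) -- label information lost
print(targets.view(-1).long())            # tensor([2, 0, 1]) -- correct class indices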

323-331: Return type annotation is too narrow for the metrics dict.

The function is annotated as returning dict[str, float], but metrics["classification_report"] (line 389) is assigned a string from classification_report(). This will cause type checker warnings.

Apply this diff:

-) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
+) -> tuple[float, dict[str, float | str], tuple[np.ndarray, np.ndarray] | None]:

And update line 375:

-    metrics = {"loss": float(avg_loss)}
+    metrics: dict[str, float | str] = {"loss": float(avg_loss)}
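A tiny check of why the widened value type is needed (labels are made up; classification_report returns a str by default):

from sklearn.metrics import classification_report

y_true, y_pred = [0, 1, 1], [0, 1, 0]
metrics: dict[str, float | str] = {"loss": 0.42}
metrics["classification_report"] = classification_report(y_true, y_pred)
print(type(metrics["classification_report"]).__name__)   # str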
src/mqt/predictor/ml/predictor.py (6)

421-434: Intermediate y tensor is computed but discarded.

The y tensor (device index) at line 428 is immediately discarded in generate_training_data (line 328: x, _y, edge_idx, ...) and replaced with the full device scores tensor. This intermediate computation serves no purpose and adds confusion.


500-506: num_outputs parameter is ignored; k_folds can cause KFold error.

Two issues:

  1. Line 503 overwrites the num_outputs parameter with max(1, len(self.devices)), making the function argument useless.

  2. Line 506 creates KFold(n_splits=k_folds), but k_folds comes from min(len(y_train), 5) at line 609. If the dataset has only one sample, k_folds is 1, which is invalid for KFold.

Apply this diff:

-        num_outputs = max(1, len(self.devices))
+        # Honor the parameter if provided, otherwise use device count
+        num_outputs = max(1, num_outputs) if num_outputs > 0 else max(1, len(self.devices))

         # Split into k-folds
+        k_folds = max(2, k_folds)  # KFold requires at least 2 splits
         kf = KFold(n_splits=k_folds, shuffle=True)
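For reference, scikit-learn rejects n_splits < 2 at construction time, which is what the clamp above guards against (at split time, n_splits must also not exceed the number of samples):

from sklearn.model_selection import KFold

try:
    KFold(n_splits=1)
except ValueError as err:
    print(err)                                 # explains that n_splits must be 2 or more
kf = KFold(n_splits=max(2, 1), shuffle=True)   # clamped value, as suggested above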

626-666: Critical: mlp_units and classes not stored in JSON; prediction will fail.

Three issues will cause KeyError or AttributeError at prediction time:

  1. Line 628-630: json_dict is populated from study.best_trial.params which stores "mlp" as a string (e.g., "64,32"), but prediction at line 845 reads json_dict["mlp_units"] expecting a list.

  2. Line 632: num_outputs is stored but classes (device labels) are not. Line 858 tries to access gnn_model.classes which doesn't exist.

  3. Lines 633-662: Duplicate GNN instantiation can be consolidated.

Apply this diff to store required metadata:

         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
+        json_dict["classes"] = [dev.description for dev in self.devices]  # Store device labels
+
+        # Consolidate model instantiation
+        output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+        model = GNN(
+            in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+            num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+            hidden_dim=dict_best_hyper["hidden_dim"],
+            num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+            mlp_units=mlp_units,
+            output_dim=output_dim,
+            dropout_p=dict_best_hyper["dropout"],
+            bidirectional=dict_best_hyper["bidirectional"],
+            use_sag_pool=dict_best_hyper["sag_pool"],
+            sag_ratio=0.7,
+            conv_activation=torch.nn.functional.leaky_relu,
+            mlp_activation=torch.nn.functional.leaky_relu,
+        ).to("cuda" if torch.cuda.is_available() else "cpu")
-        if self.figure_of_merit != "hellinger_distance":
-            model = GNN(
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            model = GNN(
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")

687-699: Device type mismatch and wrong evaluator for regression tasks.

Two issues:

  1. Line 687: train_model expects device: str | None, but device is a torch.device object. This works due to duck typing but is a type mismatch.

  2. Lines 694-699: The verbose block always calls evaluate_classification_model, even when task="regression" (Hellinger distance). This produces meaningless metrics for regression.

Apply this diff:

         train_model(
             model,
             train_loader,
             optimizer,
             loss_fn,
             task=task,
             num_epochs=num_epochs,
-            device=device,
+            device=str(device),
             verbose=verbose,
             val_loader=val_loader,
             patience=30,
             min_delta=0.0,
             restore_best=True,
         )
         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

835-867: Critical: Multiple bugs will cause GNN prediction to fail at runtime.

Four issues in the GNN prediction path:

  1. Line 845: json_dict["mlp_units"] will raise KeyError since training saves "mlp" as a string, not "mlp_units" as a list.

  2. Line 856: Uses num_gates but PyTorch Geometric expects num_nodes. Inconsistent with training code at line 335.

  3. Line 858: gnn_model.classes is never defined on the GNN class. This will raise AttributeError.

  4. Lines 860-864: Output shape mismatch. The model returns shape [1, num_devices] for a single Data object, but len(outputs) returns 1 (batch dimension), not num_devices. The assertion will fail.

Apply this diff:

         gnn_model.load_state_dict(torch.load(path))
         x, edge_index, number_of_gates = create_dag(qc)
-        feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
+        feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
         gnn_model.eval()
-        class_labels = gnn_model.classes
+        class_labels = json_dict["classes"]  # Load from JSON metadata
         with torch.no_grad():
             outputs = gnn_model(feature_vector)
+        outputs = outputs.squeeze(0)  # Remove batch dimension [1, num_devices] -> [num_devices]
         assert class_labels is not None
         if len(class_labels) != len(outputs):
             msg = "outputs and class_labels must be same length"
             raise ValueError(msg)

         pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
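A short illustration of the shape handling above (scores and device labels are made up):

import torch

outputs = torch.tensor([[0.10, 0.72, 0.18]])   # model output for one sample: [1, num_devices]
class_labels = ["dev_a", "dev_b", "dev_c"]

scores = outputs.squeeze(0)                    # drop the batch dimension -> [num_devices]
assert len(scores) == len(class_labels)
pairs = sorted(zip(scores.tolist(), class_labels), reverse=True)
print(pairs[0])                                # (0.72, 'dev_b')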

Also fix the model instantiation to parse mlp from JSON:

         gnn_model = GNN(
             in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
             num_conv_wo_resnet=json_dict["num_conv_wo_resnet"],
             hidden_dim=json_dict["hidden_dim"],
             num_resnet_layers=json_dict["num_resnet_layers"],
-            mlp_units=json_dict["mlp_units"],
+            mlp_units=json_dict.get("mlp_units") or ([] if json_dict["mlp"] == "none" else [int(x) for x in json_dict["mlp"].split(",")]),
             output_dim=json_dict["num_outputs"],

606-606: Dead code: list comprehension result is discarded.

Line 606 computes [dev.description for dev in self.devices] but the result is not assigned to any variable. This appears to be leftover code.

Remove the dead code:

-            [dev.description for dev in self.devices]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9f698e and 8b79177.

📒 Files selected for processing (3)
  • src/mqt/predictor/ml/helper.py (4 hunks)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
  • tests/device_selection/test_predictor_ml.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/mqt/predictor/ml/helper.py (1)
src/mqt/predictor/utils.py (1)
  • calc_supermarq_features (94-146)
tests/device_selection/test_predictor_ml.py (3)
src/mqt/predictor/ml/helper.py (1)
  • get_path_training_data (36-38)
src/mqt/predictor/rl/helper.py (1)
  • get_path_training_data (97-99)
src/mqt/predictor/ml/predictor.py (1)
  • Predictor (166-796)
🪛 Ruff (0.14.6)
src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


570-570: Boolean-typed positional argument in function definition

(FBT001)


570-570: Boolean default positional argument in function definition

(FBT002)


800-800: Boolean-typed positional argument in function definition

(FBT001)


800-800: Boolean default positional argument in function definition

(FBT002)

tests/device_selection/test_predictor_ml.py

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)


113-113: Boolean-typed positional argument in function definition

(FBT001)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (14)
tests/device_selection/test_predictor_ml.py (4)

38-43: Parameterized test setup looks correct.

The test parameterization covers the three key scenarios: RF-only, GNN-only, and GNN with verbose output. This provides good coverage for the dual training paths.


57-64: Test invocation correctly passes gnn and verbose parameters.

The setup_device_predictor call properly forwards the gnn and verbose parameters to test both training paths.


68-73: Conditional assertions correctly validate expected artifacts.

The test properly checks for graph_dataset_expected_fidelity.pt when gnn=True and the numpy-based files when gnn=False. This aligns with the dual data storage approach.


112-117: Parameterized test correctly covers both gnn and rf paths for error handling.

The test now validates that FileNotFoundError is raised for both the RF and GNN paths when training data is missing.

src/mqt/predictor/ml/helper.py (6)

53-58: GNN checkpoint path helper looks correct.

The helper properly constructs the path to the GNN checkpoint file using an f-string as suggested in prior review.


119-152: OpenQASM 3.0 gate list is well-documented.

The gate list includes a reference to the OpenQASM 3.0 specification and a note about verification. This is helpful for future maintenance.


192-301: DAG featurization handles edge cases correctly.

The create_dag function now includes proper guards for:

  • Empty edges (returns torch.empty((2, 0), dtype=torch.long))
  • Empty topo_nodes (returns early with zero critical flags)

The feature layout is well-documented in the docstring.


405-469: Regression evaluator is well-structured.

The function correctly handles edge cases (empty predictions, constant targets for R²), uses proper metric computation, and has consistent verbose output.


472-572: Training loop with early stopping is correctly implemented.

The train_model function properly:

  • Handles both classification and regression tasks
  • Implements early stopping with patience and min_delta
  • Restores best model weights when restore_best=True
  • Transfers model to the correct device

574-585: TrainingData dataclass correctly supports both numpy and graph data.

The widened type hints (NDArray | list[Data] for X, NDArray | Tensor for y) properly accommodate both classical and GNN training paths.

src/mqt/predictor/ml/predictor.py (4)

18-63: Imports are well-organized for the new GNN functionality.

The new imports for torch, optuna, torch_geometric, and related modules are properly structured with TYPE_CHECKING guards where appropriate.


83-98: Type definitions clearly document the sample formats.

The GNNSample, FeatureSample, and TrainingSample type aliases make the dual data paths explicit. The TrainGNNKwargs TypedDict is a clean way to forward GNN-specific arguments.


100-163: Setup function correctly routes to GNN or RF training paths.

The setup_device_predictor function properly:

  • Accepts gnn flag and gnn_kwargs
  • Routes to train_gnn_model when gnn=True
  • Maintains backward compatibility with RF path

327-341: GNN training sample creation correctly populates Data object.

The code properly creates a torch_geometric.data.Data object with node features, edge index, and target labels from the score dictionary.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/device_selection/test_predictor_ml.py (1)

75-78: Prediction test always uses RF path regardless of gnn parameter.

Line 76 calls predict_device_for_figure_of_merit without passing gnn=gnn. When gnn=True, the test trains a GNN model but then attempts prediction with the RF model (which may not exist), causing a mismatch.

     test_qc = get_benchmark("ghz", BenchmarkLevel.ALG, 3)
-    predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity")
+    predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity", gnn=gnn)

     assert predicted.description == "ibm_falcon_127"
♻️ Duplicate comments (10)
src/mqt/predictor/ml/predictor.py (9)

107-108: Consider making gnn keyword-only to satisfy Ruff FBT001/FBT002.

Boolean positional arguments can be confusing at call sites. Adding *, before gnn forces callers to use the keyword form.

 def setup_device_predictor(
     devices: list[Target],
     figure_of_merit: figure_of_merit = "expected_fidelity",
     path_uncompiled_circuits: Path | None = None,
     path_compiled_circuits: Path | None = None,
     path_training_data: Path | None = None,
     timeout: int = 600,
+    *,
     gnn: bool = False,
     **gnn_kwargs: Unpack[TrainGNNKwargs],
 ) -> bool:
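A minimal sketch of the keyword-only pattern (function and flag names are placeholders, not the project's API):

def setup(devices: list[str], *, gnn: bool = False) -> None:
    print(f"training {'GNN' if gnn else 'random forest'} predictor for {len(devices)} devices")

setup(["ibm_falcon_127"], gnn=True)   # OK: the flag must be spelled out
# setup(["ibm_falcon_127"], True)     # TypeError: boolean can no longer be passed positionally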

421-429: The intermediate y tensor (line 428) is computed but discarded.

The y tensor created here is unpacked as _y and ignored in generate_training_data (line 328), then replaced with a different tensor. This computation is wasted.

Also, line 421 assigns a dict to scores_list, which is misleading naming.

-        scores_list = scores  # list(scores.values())
+        scores_dict = scores
         target_label = max(scores, key=lambda k: scores[k])

         qc = QuantumCircuit.from_qasm_file(path_uncompiled_circuit / file)
         training_sample: TrainingSample
         if self.gnn:
             x, edge_index, number_of_gates = create_dag(qc)
-            y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
-            training_sample = (x, y, edge_index, number_of_gates, target_label)
+            training_sample = (x, None, edge_index, number_of_gates, target_label)

503-506: num_outputs parameter is overwritten; k_folds may be invalid for tiny datasets.

  1. Line 503 overwrites the num_outputs parameter passed to the function, making it effectively unused. Either remove it from the signature or honor the passed value.

  2. Line 506 creates KFold(n_splits=k_folds), but k_folds comes from min(len(training_data.y_train), 5) in train_gnn_model (line 609), which could be 1 for a single sample. KFold requires n_splits >= 2.

-        num_outputs = max(1, len(self.devices))
+        # Use passed num_outputs or fall back to device count
+        num_outputs = num_outputs if num_outputs > 0 else max(1, len(self.devices))

         # Split into k-folds
+        k_folds = max(2, k_folds)  # KFold requires at least 2 splits
         kf = KFold(n_splits=k_folds, shuffle=True)

626-632: Critical: mlp_units and classes not stored in JSON metadata; prediction will fail.

The json_dict is populated from study.best_trial.params, which contains "mlp" as a string (e.g., "64,32"). However, predict_device_for_figure_of_merit at line 845 expects json_dict["mlp_units"] as a list, which will raise KeyError.

Additionally, class labels must be stored for inference since gnn_model.classes (line 858) is never defined on the GNN class.

         json_dict = study.best_trial.params
         mlp_str = dict_best_hyper["mlp"]
         mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]

         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
+        json_dict["classes"] = [dev.description for dev in self.devices]  # Store class labels

633-662: Reduce duplication in GNN model construction.

The two branches are nearly identical, differing only in output_dim. Consolidate:

-        if self.figure_of_merit != "hellinger_distance":
-            model = GNN(
-                in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
-                ...
-                output_dim=len(self.devices),
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            model = GNN(
-                ...
-                output_dim=1,
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
+        output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+        device_str = "cuda" if torch.cuda.is_available() else "cpu"
+        model = GNN(
+            in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+            num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+            hidden_dim=dict_best_hyper["hidden_dim"],
+            num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+            mlp_units=mlp_units,
+            output_dim=output_dim,
+            dropout_p=dict_best_hyper["dropout"],
+            bidirectional=dict_best_hyper["bidirectional"],
+            use_sag_pool=dict_best_hyper["sag_pool"],
+            sag_ratio=0.7,
+            conv_activation=torch.nn.functional.leaky_relu,
+            mlp_activation=torch.nn.functional.leaky_relu,
+        ).to(device_str)

687-687: Type mismatch: train_model expects device: str | None, but receives torch.device.

Line 687 passes device (a torch.device object) to train_model, but the function signature expects str | None. This works due to duck typing but is a type annotation mismatch.

-            device=device,
+            device=str(device),

694-699: Verbose evaluation ignores task type; always uses classification evaluator.

For "hellinger_distance" or "estimated_hellinger_distance", task is set to "regression" (line 601), but the verbose block unconditionally calls evaluate_classification_model. This produces incorrect metrics for regression tasks.

         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

840-858: Critical: Multiple bugs will cause GNN prediction to fail at runtime.

  1. Line 845: json_dict["mlp_units"] will raise KeyError. The training saves "mlp" (a string), not "mlp_units" (a list).

  2. Line 854: torch.load(path) without weights_only=True raises a security warning in PyTorch 2.x.

  3. Line 856: Uses num_gates=number_of_gates, but Data objects created during training use num_nodes. While functionally equivalent, the inconsistent naming is confusing.

  4. Line 858: gnn_model.classes is never defined on the GNN class. This will raise AttributeError.

         gnn_model = GNN(
             in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
             num_conv_wo_resnet=json_dict["num_conv_wo_resnet"],
             hidden_dim=json_dict["hidden_dim"],
             num_resnet_layers=json_dict["num_resnet_layers"],
-            mlp_units=json_dict["mlp_units"],
+            mlp_units=json_dict.get("mlp_units") or ([] if json_dict["mlp"] == "none" else [int(x) for x in json_dict["mlp"].split(",")]),
             output_dim=json_dict["num_outputs"],
             dropout_p=json_dict["dropout"],
             bidirectional=json_dict["bidirectional"],
             use_sag_pool=json_dict["sag_pool"],
             sag_ratio=0.7,
             conv_activation=torch.nn.functional.leaky_relu,
             mlp_activation=torch.nn.functional.leaky_relu,
         ).to("cuda" if torch.cuda.is_available() else "cpu")
-        gnn_model.load_state_dict(torch.load(path))
+        gnn_model.load_state_dict(torch.load(path, weights_only=True))
         x, edge_index, number_of_gates = create_dag(qc)
-        feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
+        feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
         gnn_model.eval()
-        class_labels = gnn_model.classes
+        class_labels = json_dict.get("classes")  # Load from JSON metadata
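A hedged sketch of the safer checkpoint round-trip referenced in point 2 (module and path are placeholders; weights_only is available in recent PyTorch releases):

import torch
from torch import nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "checkpoint.pth")
state_dict = torch.load("checkpoint.pth", weights_only=True)   # restrict unpickling to tensors/containers
model.load_state_dict(state_dict)
model.eval()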

859-867: Output shape mismatch: len(outputs) returns batch dimension, not num_classes.

The GNN model returns a tensor of shape [1, num_devices] for a single Data object. len(outputs) returns 1 (the batch dimension), not the number of device scores. The comparison at line 862 will always fail unless there's exactly one device.

Additionally, outputs.tolist() on shape [1, N] returns [[score1, score2, ...]], causing the zip to iterate only once.

         with torch.no_grad():
             outputs = gnn_model(feature_vector)
+        outputs = outputs.squeeze(0)  # Remove batch dimension: [1, num_devices] -> [num_devices]
         assert class_labels is not None
         if len(class_labels) != len(outputs):
             msg = "outputs and class_labels must be same length"
             raise ValueError(msg)

         pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
src/mqt/predictor/ml/helper.py (1)

304-317: get_results_classes corrupts target indices for integer labels.

When targets are 1D or [N, 1] shaped (integer class indices), argmax(targets, dim=1) returns all zeros regardless of actual values. In evaluate_classification_model, targets are unsqueezed to [N, 1] at lines 360-361, triggering this bug.

Make the function robust to both encodings:

 def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """Return predicted and target class indices.

     Arguments:
         preds: model predictions
-        targets: ground truth targets
+        targets: ground truth targets (one-hot encoded or integer class labels)
     Returns:
         pred_idx: predicted class indices
         targets_idx: target class indices
     """
     pred_idx = torch.argmax(preds, dim=1)
-    targets_idx = torch.argmax(targets, dim=1)
+    if targets.dim() == 1:
+        targets_idx = targets.long()
+    elif targets.size(1) == 1:
+        targets_idx = targets.view(-1).long()
+    else:
+        targets_idx = torch.argmax(targets, dim=1)

     return pred_idx, targets_idx
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9f698e and 8b79177.

📒 Files selected for processing (3)
  • src/mqt/predictor/ml/helper.py (4 hunks)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
  • tests/device_selection/test_predictor_ml.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/device_selection/test_predictor_ml.py (2)
src/mqt/predictor/ml/helper.py (1)
  • get_path_training_data (36-38)
src/mqt/predictor/rl/helper.py (1)
  • get_path_training_data (97-99)
src/mqt/predictor/ml/predictor.py (4)
src/mqt/predictor/ml/gnn.py (1)
  • GNN (191-295)
src/mqt/predictor/hellinger/utils.py (1)
  • get_hellinger_model_path (135-144)
src/mqt/predictor/ml/helper.py (12)
  • TrainingData (575-585)
  • create_dag (192-301)
  • create_feature_vector (164-189)
  • evaluate_classification_model (323-402)
  • evaluate_regression_model (405-469)
  • get_openqasm3_gates (119-151)
  • get_path_trained_model (48-50)
  • get_path_trained_model_gnn (53-57)
  • get_path_training_circuits (60-62)
  • get_path_training_circuits_compiled (65-67)
  • get_path_training_data (36-38)
  • train_model (472-571)
src/mqt/predictor/rl/helper.py (3)
  • get_path_trained_model (102-104)
  • get_path_training_circuits (107-109)
  • get_path_training_data (97-99)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)


113-113: Boolean-typed positional argument in function definition

(FBT001)

src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


570-570: Boolean-typed positional argument in function definition

(FBT001)


570-570: Boolean default positional argument in function definition

(FBT002)


800-800: Boolean-typed positional argument in function definition

(FBT001)


800-800: Boolean default positional argument in function definition

(FBT002)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (11)
src/mqt/predictor/ml/predictor.py (3)

18-86: LGTM on imports and type definitions.

The new imports for GNN support (torch, torch_geometric, optuna, gc, json) and the type aliases (GNNSample, FeatureSample, TrainingSample) are well-organized and correctly typed.


92-98: LGTM on TrainGNNKwargs TypedDict.

Good use of TypedDict with total=False to make all GNN-specific arguments optional. This cleanly separates GNN hyperparameters from the main function signature.


762-781: LGTM on conditional data loading for GNN vs RF paths.

The branching logic correctly handles both numpy arrays for RF and PyTorch tensors for GNN, with appropriate file formats (.npy vs .pt) and data extraction methods.

tests/device_selection/test_predictor_ml.py (3)

38-43: Good test parameterization for RF and GNN paths.

The parameterization with ids=["rf", "gnn", "gnn_verbose"] provides clear test identification and covers both model types plus verbose output.


96-99: LGTM on cleanup handling both file formats.

The cleanup correctly removes both .npy and .pt files to handle artifacts from both RF and GNN test runs.


112-117: LGTM on parameterized test for _get_prepared_training_data.

The test correctly parameterizes the gnn flag and verifies that FileNotFoundError is raised when training data is missing for both paths.

src/mqt/predictor/ml/helper.py (5)

53-57: LGTM on get_path_trained_model_gnn helper.

Clean implementation using f-string and clear docstring describing it returns a checkpoint file path.


119-151: LGTM on OpenQASM 3.0 gate list.

Good documentation with spec reference and version snapshot note for future maintenance.


192-301: LGTM on create_dag implementation.

Robust handling of edge cases:

  • Empty edges correctly produce a (2, 0) tensor (lines 265-268)
  • Empty topo_nodes returns early with zero critical flags (lines 272-276)
  • Unknown gates raise a clear ValueError (lines 247-249)

The feature vector layout is well-documented in the docstring.
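For illustration, the empty-edge guard boils down to the following shape convention (the edge list here is hypothetical):

import torch

edges: list[tuple[int, int]] = []   # a DAG with no gate-to-gate dependencies
if edges:
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
else:
    edge_index = torch.empty((2, 0), dtype=torch.long)
print(edge_index.shape)             # torch.Size([2, 0])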


405-469: LGTM on regression evaluator.

Clean implementation with proper handling of edge cases (empty results, constant targets for R² calculation). The verbose output provides useful diagnostic information.
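A small example of the constant-target edge case (values are made up): R² is not well defined when all targets are identical, hence the NaN fallback.

import numpy as np
from sklearn.metrics import r2_score

targets = np.array([0.5, 0.5, 0.5])
preds = np.array([0.4, 0.5, 0.6])
r2 = float("nan") if np.all(targets == targets[0]) else float(r2_score(targets, preds))
print(r2)   # nan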


574-585: LGTM on updated TrainingData dataclass.

The widened type annotations correctly support both numpy arrays for RF and torch_geometric.data.Data lists for GNN. The docstring clarifies the dual-mode support.

Comment on lines +323 to +402
def evaluate_classification_model(
model: nn.Module,
loader: torch_geometric.loader.DataLoader,
loss_fn: nn.Module,
device: str,
*,
return_arrays: bool = False,
verbose: bool = False,
) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
"""Evaluate a classification model with the given loss function and compute accuracy metrics.
Arguments:
model: classification model to be evaluated
loader: data loader for the evaluation dataset
loss_fn: loss function for evaluation
device: device to be used for evaluation (cuda or cpu)
return_arrays: whether to return prediction and target arrays
verbose: whether to print the metrics results.
Returns:
avg_loss: average loss over the loader
metrics: {"custom_accuracy": ..., "classification_report": ..., "mse": ..., "rmse": ..., "mae": ..., "r2": ...}
arrays: (preds, y_true) if return_arrays=True, else None.
"""
device = torch.device(device)

model.eval()
total_loss, total = 0.0, 0
all_preds, all_targets = [], []

with torch.no_grad():
for batch in loader:
batch_device = batch.to(device)
preds = model(batch_device)
preds = torch.clamp(preds, 0.0, 1.0)
targets = batch_device.y.float()

if targets.dim() == 1:
targets = targets.unsqueeze(1)
if preds.shape != targets.shape:
msg = f"Shape mismatch: preds {preds.shape} vs targets {targets.shape}"
raise ValueError(msg)

bs = targets.size(0)
loss = loss_fn(preds, targets)
total_loss += loss.item() * bs
total += bs

all_preds.append(preds.detach().cpu())
all_targets.append(targets.detach().cpu())

avg_loss = total_loss / max(1, total)
metrics = {"loss": float(avg_loss)}

if not all_preds or not all_targets:
arrays = (np.array([]), np.array([])) if return_arrays else None
return avg_loss, metrics, arrays

preds = torch.cat(all_preds, dim=0)
targets = torch.cat(all_targets, dim=0)

# --- compute accuracy ---
pred_classes, target_classes = get_results_classes(preds, targets)
acc = accuracy_score(target_classes, pred_classes)
classification_report_res = classification_report(target_classes, pred_classes)
metrics["custom_accuracy"] = float(acc)
metrics["classification_report"] = classification_report_res

if verbose:
mse = mean_squared_error(targets.numpy().reshape(-1), preds.numpy().reshape(-1))
mae = mean_absolute_error(targets.numpy().reshape(-1), preds.numpy().reshape(-1))
rmse = float(np.sqrt(mse))
if targets.size(0) < 2 or torch.all(targets == targets[0]):
r2 = float("nan")
else:
r2 = float(r2_score(targets.numpy().reshape(-1), preds.numpy().reshape(-1)))
metrics.update({"mse": float(mse), "rmse": float(rmse), "mae": float(mae), "r2": float(r2)})

arrays = (preds.numpy(), targets.numpy()) if return_arrays else None
return avg_loss, metrics, arrays

🧹 Nitpick | 🔵 Trivial

LGTM on classification evaluator with empty loader guard.

The empty loader check (lines 377-379) prevents torch.cat crash on empty lists. The evaluation logic correctly computes accuracy and classification report.

One minor type annotation issue: the return type annotation is dict[str, float], but metrics["classification_report"] is a string. Consider widening to dict[str, float | str].

🤖 Prompt for AI Agents
In src/mqt/predictor/ml/helper.py around lines 323 to 402, the return type
annotation currently declares the metrics dict as dict[str, float] but the
function stores a classification report string under "classification_report";
change the return type to reflect both numeric and string values (for example
dict[str, float | str] or dict[str, Any]) and update the function signature
annotation accordingly so the declared return type matches the actual metrics
contents.

Comment on lines +534 to +542
if val_loader is not None:
if task == "classification":
val_loss, val_metrics, _ = evaluate_classification_model(
model, val_loader, loss_fn, device=str(device), verbose=True
)
elif task == "regression":
val_loss, val_metrics, _ = evaluate_regression_model(
model, val_loader, loss_fn, device=str(device), verbose=True
)

🧹 Nitpick | 🔵 Trivial

Validation evaluation always uses verbose=True, causing noisy output.

Lines 537 and 541 hardcode verbose=True for validation evaluation, which prints metrics every epoch regardless of the outer verbose parameter. This may produce excessive output during training.

             if task == "classification":
                 val_loss, val_metrics, _ = evaluate_classification_model(
-                    model, val_loader, loss_fn, device=str(device), verbose=True
+                    model, val_loader, loss_fn, device=str(device), verbose=False
                 )
             elif task == "regression":
                 val_loss, val_metrics, _ = evaluate_regression_model(
-                    model, val_loader, loss_fn, device=str(device), verbose=True
+                    model, val_loader, loss_fn, device=str(device), verbose=False
                 )
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/helper.py around lines 534 to 542, the validation
evaluation calls hardcode verbose=True which forces per-epoch printing; change
both evaluate_classification_model and evaluate_regression_model calls to pass
the surrounding function's verbose parameter (e.g., verbose=verbose) so
validation logging respects the outer verbosity flag (or explicitly pass False
if you want to silence validation always).

Comment on lines +327 to 341
if self.gnn:
x, _y, edge_idx, n_nodes, target_label = training_sample
name_device = sorted(scores.keys())
value_device = [scores[i] for i in name_device]
gnn_training_sample = Data(
x=x,
y=torch.tensor(value_device, dtype=torch.float32),
edge_index=edge_idx,
num_nodes=n_nodes,
target_label=target_label,
)

training_data.append(gnn_training_sample if self.gnn else training_sample)
names_list.append(circuit_name)
scores_list.append(scores)

⚠️ Potential issue | 🟡 Minor

Device score ordering may diverge from self.devices order.

Line 329 uses sorted(scores.keys()) for device ordering, but self.devices is sorted by dev.description at initialization (line 189-191). If the sorted order of scores.keys() differs from the sorted order of device descriptions, the y tensor indices won't match the model's expected output ordering.
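A minimal sketch of aligning the target vector with a canonical device list (device names and scores are made up); the suggested diff below applies the same idea inside generate_training_data.

class Dev:
    def __init__(self, description: str) -> None:
        self.description = description

devices = sorted([Dev("ionq_aria"), Dev("ibm_falcon_127")], key=lambda d: d.description)
scores = {"ionq_aria": 0.91, "ibm_falcon_127": 0.87}

aligned = [scores.get(dev.description, -1.0) for dev in devices]
print(aligned)   # [0.87, 0.91] -- ordered by the device list, not by scores insertion order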

Consider using the canonical self.devices ordering:

             if self.gnn:
                 x, _y, edge_idx, n_nodes, target_label = training_sample
-                name_device = sorted(scores.keys())
-                value_device = [scores[i] for i in name_device]
+                # Use canonical device ordering from self.devices
+                value_device = [scores.get(dev.description, -1.0) for dev in self.devices]
                 gnn_training_sample = Data(
                     x=x,
                     y=torch.tensor(value_device, dtype=torch.float32),

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (8)
src/mqt/predictor/ml/predictor.py (6)

633-662: Consolidate duplicate GNN instantiation.

The two branches (lines 634-647 and 649-662) are nearly identical, differing only in output_dim. This duplication can be eliminated.

-        if self.figure_of_merit != "hellinger_distance":
-            model = GNN(
-                in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
-                ...
-                output_dim=len(self.devices),
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            model = GNN(
-                ...
-                output_dim=1,
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
+        output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+        device_str = "cuda" if torch.cuda.is_available() else "cpu"
+        model = GNN(
+            in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+            num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+            hidden_dim=dict_best_hyper["hidden_dim"],
+            num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+            mlp_units=mlp_units,
+            output_dim=output_dim,
+            dropout_p=dict_best_hyper["dropout"],
+            bidirectional=dict_best_hyper["bidirectional"],
+            use_sag_pool=dict_best_hyper["sag_pool"],
+            sag_ratio=0.7,
+            conv_activation=torch.nn.functional.leaky_relu,
+            mlp_activation=torch.nn.functional.leaky_relu,
+        ).to(device_str)

673-687: Double train/val split and type mismatch in train_model call.

  1. Double split: _get_prepared_training_data already performs a 70/30 train/test split (line 784). Here, lines 673-675 split again (80/20), resulting in effective proportions of ~56%/14%/30%. If intentional, document this; otherwise use training_data.X_test as the validation set.

  2. Type mismatch: Line 687 passes device (a torch.device object) to train_model, which expects device: str | None. While this works due to duck typing, it's inconsistent.

-        device=device,
+        device=str(device),

854-867: Multiple critical bugs will cause GNN prediction to fail at runtime.

  1. Line 856: Uses num_gates but PyTorch Geometric's Data class expects num_nodes. Use the standard attribute name for consistency:

    -feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
    +feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
  2. Line 858: gnn_model.classes is never defined on the GNN class and is never set during training. This will raise AttributeError. Load class labels from JSON metadata instead (after adding json_dict["classes"] during training as noted above):

    -class_labels = gnn_model.classes
    +class_labels = json_dict["classes"]
  3. Lines 862-864: The length comparison is incorrect. outputs has shape [1, num_devices] (batch dimension from single Data object), so len(outputs) returns 1, not the number of devices. The assertion will always fail unless there's exactly one device. Fix by squeezing the batch dimension:

    +outputs = outputs.squeeze(0)  # Remove batch dim: [1, num_devices] -> [num_devices]
     assert class_labels is not None
     if len(class_labels) != len(outputs):

628-632: mlp_units not stored in JSON; inference will fail with KeyError.

Line 845 in predict_device_for_figure_of_merit reads json_dict["mlp_units"], but here only the raw mlp string is saved (via study.best_trial.params). The parsed mlp_units list must be stored explicitly.

         mlp_str = dict_best_hyper["mlp"]
         mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]

         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["mlp_units"] = mlp_units

503-506: num_outputs parameter is overwritten; k_folds may be invalid for tiny datasets.

  1. Line 503 overwrites num_outputs with max(1, len(self.devices)), making the function parameter effectively unused. Either remove the parameter or honor it.

  2. Line 506 creates KFold(n_splits=k_folds, ...) but k_folds can be 1 when len(y_train) == 1, which is invalid for KFold. Guard against this:

+        k_folds = max(2, k_folds)  # KFold requires at least 2 splits
         kf = KFold(n_splits=k_folds, shuffle=True)

And either remove the num_outputs parameter from the signature or use it:

-        num_outputs = max(1, len(self.devices))
+        if num_outputs is None or num_outputs < 1:
+            num_outputs = max(1, len(self.devices))

694-699: Verbose test block always calls classification evaluator regardless of task.

When task == "regression" (for Hellinger distance), lines 696-698 still call evaluate_classification_model, which computes accuracy metrics that are meaningless for regression. Branch on the task variable.

         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
src/mqt/predictor/ml/helper.py (2)

304-317: get_results_classes will corrupt target labels when targets are shaped [N, 1].

This function applies argmax(targets, dim=1) unconditionally. When targets is shaped [N, 1] (as happens after targets.unsqueeze(1) at line 361), argmax along dim=1 always returns 0 regardless of the actual class index. This causes accuracy metrics to be computed incorrectly.

Apply this diff to handle both one-hot and integer label encodings:

 def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
     """Return predicted and target class indices.

     Arguments:
         preds: model predictions
-        targets: ground truth targets
+        targets: ground truth targets (one-hot encoded or integer class labels)
     Returns:
         pred_idx: predicted class indices
         targets_idx: target class indices
     """
     pred_idx = torch.argmax(preds, dim=1)
-    targets_idx = torch.argmax(targets, dim=1)
+    if targets.dim() == 1:
+        targets_idx = targets.long()
+    elif targets.size(1) == 1:
+        targets_idx = targets.view(-1).long()
+    else:
+        targets_idx = torch.argmax(targets, dim=1)

     return pred_idx, targets_idx

323-331: Return type annotation is too narrow for metrics dict.

The function is annotated as returning dict[str, float], but metrics["classification_report"] at line 389 is assigned a string from classification_report(). This type mismatch will cause static type checkers to flag issues.

-) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
+) -> tuple[float, dict[str, float | str], tuple[np.ndarray, np.ndarray] | None]:

And at line 375:

-    metrics = {"loss": float(avg_loss)}
+    metrics: dict[str, float | str] = {"loss": float(avg_loss)}
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9f698e and 8b79177.

📒 Files selected for processing (3)
  • src/mqt/predictor/ml/helper.py (4 hunks)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
  • tests/device_selection/test_predictor_ml.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/mqt/predictor/ml/predictor.py (6)
src/mqt/predictor/ml/gnn.py (1)
  • GNN (191-295)
src/mqt/predictor/hellinger/utils.py (1)
  • get_hellinger_model_path (135-144)
src/mqt/predictor/ml/helper.py (9)
  • TrainingData (575-585)
  • create_dag (192-301)
  • create_feature_vector (164-189)
  • evaluate_classification_model (323-402)
  • evaluate_regression_model (405-469)
  • get_openqasm3_gates (119-151)
  • get_path_trained_model (48-50)
  • get_path_trained_model_gnn (53-57)
  • train_model (472-571)
tests/device_selection/test_predictor_ml.py (2)
  • path_uncompiled_circuits (27-29)
  • path_compiled_circuits (33-35)
tests/hellinger_distance/test_estimated_hellinger_distance.py (1)
  • device (49-51)
tests/compilation/test_reward.py (1)
  • device (38-40)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)


113-113: Boolean-typed positional argument in function definition

(FBT001)

src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


570-570: Boolean-typed positional argument in function definition

(FBT001)


570-570: Boolean default positional argument in function definition

(FBT002)


800-800: Boolean-typed positional argument in function definition

(FBT001)


800-800: Boolean default positional argument in function definition

(FBT002)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (12)
tests/device_selection/test_predictor_ml.py (3)

38-43: Test parameterization covers both RF and GNN paths appropriately.

The parameterization includes gnn=False (RF), gnn=True, and gnn=True with verbose=True, providing good coverage for both training paths. The test IDs (rf, gnn, gnn_verbose) are clear and descriptive.


68-73: Artifact assertions correctly validate the expected output format per mode.

The conditional checks appropriately verify:

  • GNN mode produces graph_dataset_*.pt
  • RF mode produces training_data_*.npy, names_list_*.npy, and scores_list_*.npy

This aligns with the generate_training_data implementation in predictor.py.


112-115: LGTM!

The parameterization correctly tests both RF and GNN paths for the _get_prepared_training_data failure case.

src/mqt/predictor/ml/helper.py (6)

53-57: LGTM!

The function correctly returns the path to the GNN checkpoint file using an f-string as previously suggested. The docstring accurately describes the return value.


119-151: LGTM!

The OpenQASM 3.0 gate list includes the verification note and reference to the specification. The gates match the standard library documented at the referenced URL.


192-301: DAG featurization is well-structured with appropriate edge-case handling.

The implementation correctly handles:

  • Empty edge sets (lines 265-268)
  • Empty topo_nodes (lines 272-276)
  • Unknown gates with clear error messages (lines 247-249)

The feature layout is comprehensive: one-hot gate encoding, sin/cos parameters, arity, controls, num_params, critical path flag, and fan-out proportion.
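
For intuition, a minimal sketch of how such a per-node vector could be assembled. The gate vocabulary, the number of parameter slots, and the helper name are illustrative placeholders, not the exact layout used by create_dag:

import math

GATE_VOCAB = ["h", "x", "cx", "rz", "measure"]  # stand-in for get_openqasm3_gates()
MAX_PARAMS = 3  # sin/cos of up to three angles -> 6 slots


def encode_node(
    gate: str,
    params: list[float],
    arity: int,
    num_controls: int,
    on_critical_path: bool,
    fanout_fraction: float,
) -> list[float]:
    one_hot = [1.0 if gate == g else 0.0 for g in GATE_VOCAB]
    angles = (params + [0.0] * MAX_PARAMS)[:MAX_PARAMS]
    trig = [f(a) for a in angles for f in (math.sin, math.cos)]  # sin/cos per angle
    return [
        *one_hot,
        *trig,
        float(arity),
        float(num_controls),
        float(len(params)),
        float(on_critical_path),
        fanout_fraction,
    ]


vec = encode_node("rz", [0.5], arity=1, num_controls=0, on_critical_path=True, fanout_fraction=0.25)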


405-469: LGTM!

The regression evaluator correctly handles empty loaders, computes standard metrics (RMSE, MAE, R²), and guards against degenerate cases (constant targets, insufficient samples) by returning NaN.
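
As a rough illustration of those guards (function and variable names here are hypothetical, not the helper's actual API):

import numpy as np


def regression_metrics(preds: np.ndarray, targets: np.ndarray) -> dict[str, float]:
    err = preds - targets
    rmse = float(np.sqrt(np.mean(err**2)))
    mae = float(np.mean(np.abs(err)))
    var = float(np.var(targets))
    # R^2 is undefined for constant targets or fewer than two samples -> NaN
    r2 = float("nan") if len(targets) < 2 or var == 0.0 else 1.0 - float(np.mean(err**2)) / var
    return {"rmse": rmse, "mae": mae, "r2": r2}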


472-572: LGTM!

The training loop is well-structured with:

  • Early stopping based on validation loss with patience and min_delta (see the sketch after this list)
  • Best model state restoration
  • Proper device handling
  • Task-specific evaluation branching
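
A minimal sketch of the early-stopping pattern from the first bullet above; the training and validation callbacks are stubs, so every name here is illustrative rather than the project's actual helper:

from collections.abc import Callable
import copy

import torch


def fit_with_early_stopping(
    model: torch.nn.Module,
    train_step: Callable[[torch.nn.Module], None],
    val_step: Callable[[torch.nn.Module], float],
    epochs: int = 100,
    patience: int = 10,
    min_delta: float = 1e-4,
) -> torch.nn.Module:
    best_loss, best_state, stale = float("inf"), None, 0
    for _ in range(epochs):
        train_step(model)
        val_loss = val_step(model)
        if val_loss < best_loss - min_delta:  # meaningful improvement
            best_loss, stale = val_loss, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            stale += 1
            if stale >= patience:  # no improvement for `patience` epochs in a row
                break
    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best checkpoint
    return model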

574-585: LGTM!

The TrainingData dataclass correctly documents support for both classical (numpy) and GNN (graph) models with appropriately widened type annotations.

src/mqt/predictor/ml/predictor.py (3)

83-97: LGTM!

The type aliases and TrainGNNKwargs TypedDict are well-defined and provide clear type information for the GNN training pipeline.


326-341: LGTM!

The GNN training sample construction correctly:

  • Extracts node features and edges from the DAG
  • Creates device score tensor from sorted device names
  • Uses num_nodes (not num_gates) for the Data object

752-796: LGTM!

The _get_prepared_training_data method correctly handles both RF (numpy arrays) and GNN (torch Data objects) paths, with appropriate file loading and data preparation logic.

Comment on lines +421 to 422
        scores_list = scores  # list(scores.values())
        target_label = max(scores, key=lambda k: scores[k])

🧹 Nitpick | 🔵 Trivial

Remove or clarify the commented assignment.

Line 421 assigns scores_list = scores with a commented-out alternative # list(scores.values()). This suggests uncertainty about the intended value. Since scores is already a dict and is used directly, the comment should be removed to avoid confusion.

-        scores_list = scores  # list(scores.values())
+        scores_list = scores
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 421 to 422, remove the
ambiguous commented assignment and clarify the intended variable: replace or
delete the line assigning scores_list = scores with a single clear statement
(either remove scores_list entirely if unused, or set scores_list =
list(scores.values()) if a list of values is required) and remove the commented
alternative so the code reflects the actual intended type and usage.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (14)
tests/device_selection/test_predictor_ml.py (2)

68-73: GNN path assertions are incomplete.

The GNN path only asserts graph_dataset_expected_fidelity.pt exists, but according to the predictor code (lines 350-353), names_list_*.npy and scores_list_*.npy are saved regardless of the gnn flag. Add these assertions for the GNN branch to ensure complete validation.

     if gnn:
         assert (data_path / "graph_dataset_expected_fidelity.pt").exists()
+        assert (data_path / "names_list_expected_fidelity.npy").exists()
+        assert (data_path / "scores_list_expected_fidelity.npy").exists()
     else:

98-98: Use in operator for cleaner suffix check.

-            if file.suffix == ".npy" or file.suffix == ".pt":
+            if file.suffix in (".npy", ".pt"):
src/mqt/predictor/ml/predictor.py (12)

107-108: Consider keyword-only for gnn parameter.

Making gnn keyword-only (after *) would prevent accidental positional usage and address the Ruff FBT001/FBT002 hints.
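
For example (signature abridged; the real constructor takes more arguments):

class Predictor:
    def __init__(self, *, gnn: bool = False) -> None:
        self.gnn = gnn


Predictor(gnn=True)  # explicit keyword is required
# Predictor(True)    # would now raise TypeError, and FBT001/FBT002 no longer apply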


327-341: Device score ordering may diverge from self.devices order.

Line 329 uses sorted(scores.keys()) for device ordering, but self.devices is sorted by dev.description at initialization. If scores.keys() contains different keys or sorts differently than device descriptions, the y tensor indices won't match the model's expected output ordering.

Use the canonical self.devices ordering:

             if self.gnn:
                 x, _y, edge_idx, n_nodes, target_label = training_sample
-                name_device = sorted(scores.keys())
-                value_device = [scores[i] for i in name_device]
+                # Use canonical device ordering from self.devices
+                value_device = [scores.get(dev.description, -1.0) for dev in self.devices]
                 gnn_training_sample = Data(

421-422: Remove commented alternative.

The commented-out # list(scores.values()) adds confusion. Since scores is used directly as a dict, remove the comment.


426-429: The intermediate y tensor is computed but discarded.

Line 428 computes y = torch.tensor(...) but it's immediately discarded in generate_training_data (line 328: x, _y, edge_idx, ...). Either remove this computation or document why it exists.


606-606: Dead code: list comprehension result not assigned.

Line 606 computes [dev.description for dev in self.devices] but doesn't assign the result. This should store device labels in json_dict["classes"] for use during inference.

-            [dev.description for dev in self.devices]
+            json_dict["classes"] = [dev.description for dev in self.devices]

628-636: mlp_units not stored in JSON metadata; inference will fail.

Line 849 in predict_device_for_figure_of_merit reads json_dict["mlp_units"], but the JSON only contains "mlp" (a string). Store the parsed list:

         json_dict["num_outputs"] = (
             len(self.devices)
             if self.figure_of_merit != "hellinger_distance" and self.figure_of_merit != "estimated_hellinger_distance"
             else 1
         )
+        json_dict["mlp_units"] = mlp_units  # Store parsed list for inference

637-666: Reduce code duplication in GNN instantiation.

The two GNN instantiation blocks differ only in output_dim. Consolidate them.

+        output_dim = 1 if self.figure_of_merit in ("hellinger_distance", "estimated_hellinger_distance") else len(self.devices)
+        device_str = "cuda" if torch.cuda.is_available() else "cpu"
-        if self.figure_of_merit != "hellinger_distance" and self.figure_of_merit != "estimated_hellinger_distance":
-            model = GNN(
-                ...
-                output_dim=len(self.devices),
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            model = GNN(
-                ...
-                output_dim=1,
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
+        model = GNN(
+            in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+            num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+            hidden_dim=dict_best_hyper["hidden_dim"],
+            num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+            mlp_units=mlp_units,
+            output_dim=output_dim,
+            dropout_p=dict_best_hyper["dropout"],
+            bidirectional=dict_best_hyper["bidirectional"],
+            use_sag_pool=dict_best_hyper["sag_pool"],
+            sag_ratio=0.7,
+            conv_activation=torch.nn.functional.leaky_relu,
+            mlp_activation=torch.nn.functional.leaky_relu,
+        ).to(device_str)

677-691: Double train/val split and incorrect device type.

  1. Lines 677-679: Data is split again (80/20) after _get_prepared_training_data already performed a 70/30 split. This results in effective train/val/test proportions of 56%/14%/30%. If intentional, document it; otherwise use training_data.X_test as validation.

  2. Line 691: train_model expects device: str | None, but a torch.device object is passed. Pass str(device) or the string directly.


860-860: Use standard PyTorch Geometric attribute num_nodes instead of num_gates.

PyTorch Geometric's Data class expects num_nodes as the standard attribute. Using num_gates may cause issues with GNN operations.

-        feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
+        feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)

862-871: Critical: GNN prediction will fail at runtime due to multiple issues.

  1. Line 862: gnn_model.classes is never defined on the GNN class or set during training. This will raise AttributeError. Load from JSON metadata instead (after fixing training to save it).

  2. Lines 864-868: Shape mismatch. The model returns tensor of shape [1, num_devices] for a single Data object. len(outputs) returns 1 (batch dimension), not num_devices, so this assertion will always fail for multi-device scenarios.

  3. Line 870: outputs.tolist() with shape [1, N] becomes [[score1, score2, ...]]. Zipping this with class_labels iterates only once.

Fix by loading classes from JSON and squeezing the output:

-        class_labels = gnn_model.classes
+        class_labels = json_dict["classes"]  # Must be saved during training
         with torch.no_grad():
             outputs = gnn_model(feature_vector)
+        outputs = outputs.squeeze(0)  # Remove batch dimension [1, num_devices] -> [num_devices]
         assert class_labels is not None
         if len(class_labels) != len(outputs):
             msg = "outputs and class_labels must be same length"
             raise ValueError(msg)

         pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)

698-703: Verbose evaluation ignores task type.

The verbose block unconditionally calls evaluate_classification_model, but for "hellinger_distance" or "estimated_hellinger_distance", the task is "regression". Branch on task:

         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

503-506: num_outputs parameter is ignored; k_folds may be invalid for tiny datasets.

  1. Line 503 overwrites num_outputs with max(1, len(self.devices)), making the parameter useless. Either remove it from the signature or honor the passed value.

  2. KFold(n_splits=k_folds) with k_folds=1 (possible if len(training_data.y_train) == 1) will raise an error. Guard against this:

-        num_outputs = max(1, len(self.devices))
+        num_outputs = num_outputs if num_outputs > 0 else max(1, len(self.devices))

         # Split into k-folds
+        k_folds = max(2, k_folds)  # KFold requires at least 2 splits
         kf = KFold(n_splits=k_folds, shuffle=True)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8b79177 and 77cca31.

📒 Files selected for processing (2)
  • src/mqt/predictor/ml/predictor.py (13 hunks)
  • tests/device_selection/test_predictor_ml.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/device_selection/test_predictor_ml.py (3)
src/mqt/predictor/ml/helper.py (1)
  • get_path_training_data (36-38)
src/mqt/predictor/rl/helper.py (1)
  • get_path_training_data (97-99)
src/mqt/predictor/ml/predictor.py (1)
  • predict_device_for_figure_of_merit (803-878)
src/mqt/predictor/ml/predictor.py (2)
src/mqt/predictor/ml/gnn.py (1)
  • GNN (191-295)
src/mqt/predictor/ml/helper.py (12)
  • TrainingData (575-585)
  • create_dag (192-301)
  • create_feature_vector (164-189)
  • evaluate_classification_model (323-402)
  • evaluate_regression_model (405-469)
  • get_openqasm3_gates (119-151)
  • get_path_trained_model (48-50)
  • get_path_trained_model_gnn (53-57)
  • get_path_training_circuits (60-62)
  • get_path_training_circuits_compiled (65-67)
  • get_path_training_data (36-38)
  • train_model (472-571)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)


113-113: Boolean-typed positional argument in function definition

(FBT001)

src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


570-570: Boolean-typed positional argument in function definition

(FBT001)


570-570: Boolean default positional argument in function definition

(FBT002)


804-804: Boolean-typed positional argument in function definition

(FBT001)


804-804: Boolean default positional argument in function definition

(FBT002)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
🔇 Additional comments (6)
tests/device_selection/test_predictor_ml.py (2)

38-43: Test parametrization covers key GNN paths.

The parametrization with (False, False), (True, False), (True, True) appropriately tests RF mode, GNN mode, and GNN with verbose output. Consider adding (False, True) to also test RF with verbose mode for completeness, though this is optional.
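
The optional fourth case could be added roughly like this (the test name and body are placeholders for the existing test):

import pytest


@pytest.mark.parametrize(
    ("gnn", "verbose"),
    [(False, False), (False, True), (True, False), (True, True)],
    ids=["rf", "rf_verbose", "gnn", "gnn_verbose"],
)
def test_train_and_predict(gnn: bool, verbose: bool) -> None:  # placeholder name
    ...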


112-116: Test coverage for GNN prepared training data path is good.

The parametrization ensures both RF and GNN paths are tested for the _get_prepared_training_data error handling. The test correctly verifies that FileNotFoundError is raised when training data is missing for both model types.

src/mqt/predictor/ml/predictor.py (4)

92-98: Well-defined TypedDict for GNN kwargs.

The TrainGNNKwargs TypedDict with total=False cleanly captures optional GNN training parameters. This enables type-safe forwarding via **gnn_kwargs.
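
A minimal sketch of that pattern; the field names below are illustrative and not necessarily those of the real TrainGNNKwargs:

from typing import TypedDict


class TrainGNNKwargs(TypedDict, total=False):
    epochs: int
    batch_size: int
    verbose: bool


def train_gnn_model(**gnn_kwargs: object) -> None:  # shown standalone for illustration
    epochs = gnn_kwargs.get("epochs", 100)  # every key is optional because total=False
    print(f"training for {epochs} epochs")


kwargs: TrainGNNKwargs = {"epochs": 5, "verbose": True}
train_gnn_model(**kwargs)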


144-149: Conditional training path logic is correct.

The branching between train_random_forest_model() and train_gnn_model(**gnn_kwargs) properly routes based on the gnn flag, with appropriate logging for each path.


766-773: GNN training data loading logic is correct.

The conditional loading of .pt vs .npy files based on self.gnn flag properly handles both data formats. The weights_only=False for torch.load is necessary here since the file contains Data objects, not just tensors.
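
To illustrate the distinction (the file name is made up for the example):

import torch
from torch_geometric.data import Data

dataset = [Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))]
torch.save(dataset, "example_graph_dataset.pt")

# Pickled Data objects require full unpickling; weights_only=True would normally refuse them.
graphs = torch.load("example_graph_dataset.pt", weights_only=False)
# A plain state_dict, by contrast, can (and should) be loaded with weights_only=True.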


858-858: Use weights_only=True when loading torch checkpoints for security.

PyTorch's official documentation confirms that torch.load() without weights_only=True permits arbitrary code execution from untrusted model files through pickle deserialization. For loading state_dict (plain tensor state), weights_only=True is the recommended secure approach. PyTorch changed the default to weights_only=True starting in version 2.6 due to these security concerns.

-        gnn_model.load_state_dict(torch.load(path))
+        gnn_model.load_state_dict(torch.load(path, weights_only=True))

Comment on lines 584 to 593
        # Figure out outputs and save path
        if self.figure_of_merit == "hellinger_distance" or self.figure_of_merit == "estimated_hellinger_distance":
            if len(self.devices) != 1:
                msg = "A single device must be provided for Hellinger distance model training."
                raise ValueError(msg)
            num_outputs = 1
            save_mdl_path = str(get_hellinger_model_path(self.devices[0], gnn=True))
        else:
            num_outputs = max(1, len(self.devices))
            save_mdl_path = str(get_path_trained_model_gnn(self.figure_of_merit))

🧹 Nitpick | 🔵 Trivial

Hellinger distance figure_of_merit check can be simplified.

The repeated or checks for "hellinger_distance" and "estimated_hellinger_distance" appear multiple times. Consider using a tuple membership test for clarity:

-        if self.figure_of_merit == "hellinger_distance" or self.figure_of_merit == "estimated_hellinger_distance":
+        if self.figure_of_merit in ("hellinger_distance", "estimated_hellinger_distance"):

This pattern appears at lines 585, 600, 632-634, 637, and 734. A constant or helper could further reduce duplication.
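
A sketch of such a helper (the constant and function names are suggestions, not existing code):

HELLINGER_FIGURES_OF_MERIT = ("hellinger_distance", "estimated_hellinger_distance")


def is_hellinger(figure_of_merit: str) -> bool:
    """Return True for the two Hellinger-distance figures of merit."""
    return figure_of_merit in HELLINGER_FIGURES_OF_MERIT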

🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 584 to 593, the conditional
that checks for "hellinger_distance" or "estimated_hellinger_distance" should be
simplified and deduplicated: replace the repeated `== "hellinger_distance" or ==
"estimated_hellinger_distance"` checks with a membership test like `if
self.figure_of_merit in ("hellinger_distance", "estimated_hellinger_distance"):`
and add a small module-level constant or helper function (e.g.,
HELLINGER_FIGURES = ("hellinger_distance", "estimated_hellinger_distance") or
def is_hellinger(f): return f in HELLINGER_FIGURES) then use that
constant/helper to replace the same pattern at the other occurrences you noted
(lines ~585, 600, 632-634, 637, and 734) so the logic is identical but not
duplicated.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/hellinger_distance/test_estimated_hellinger_distance.py (1)

254-265: Windows timeout warning is incorrectly gated by gnn flag.

The warning should appear regardless of whether GNN is used. The condition sys.platform == "win32" and not gnn means the warning is only expected for RF on Windows, but the timeout limitation applies to both paths:

-        if sys.platform == "win32" and not gnn:
+        if sys.platform == "win32":

If the GNN path intentionally doesn't use timeout, add a comment explaining why.

♻️ Duplicate comments (15)
src/mqt/predictor/ml/predictor.py (12)

604-606: Dead code: list comprehension result is discarded.

Line 606 computes [dev.description for dev in self.devices] but does not assign the result. This appears to be an incomplete attempt to store class labels in json_dict. Either remove this line or assign it appropriately.

-            [dev.description for dev in self.devices]
+            # Class labels are stored below in json_dict["class_labels"]

Or simply delete line 606 since line 665 already stores the class labels correctly.


633-662: Duplicate GNN instantiation code can be consolidated.

The two branches differ only in output_dim. This was flagged before and marked as addressed, but the duplication remains.

-        if self.figure_of_merit != "hellinger_distance":
-            model = GNN(
-                in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
-                ...
-                output_dim=len(self.devices),
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            model = GNN(
-                ...
-                output_dim=1,
-                ...
-            ).to("cuda" if torch.cuda.is_available() else "cpu")
+        output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+        device_str = "cuda" if torch.cuda.is_available() else "cpu"
+        model = GNN(
+            in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+            num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+            hidden_dim=dict_best_hyper["hidden_dim"],
+            num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+            mlp_units=mlp_units,
+            output_dim=output_dim,
+            dropout_p=dict_best_hyper["dropout"],
+            bidirectional=dict_best_hyper["bidirectional"],
+            use_sag_pool=dict_best_hyper["sag_pool"],
+            sag_ratio=0.7,
+            conv_activation=torch.nn.functional.leaky_relu,
+            mlp_activation=torch.nn.functional.leaky_relu,
+        ).to(device_str)

674-676: Double train/val split results in unexpected data proportions.

_get_prepared_training_data() already performs a 70/30 train/test split. Lines 674-676 then split the training set again 80/20, resulting in effective proportions of ~56%/14%/30% (train/val/test). If intentional, document it; otherwise, consider using training_data.X_test as the validation set.
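
The arithmetic behind those proportions, spelled out (the fractions are the ones stated above):

train_outer, test_outer = 0.70, 0.30  # split in _get_prepared_training_data
train_inner, val_inner = 0.80, 0.20   # second split inside the GNN training path

effective_train = train_outer * train_inner  # 0.56 -> ~56% of the data
effective_val = train_outer * val_inner      # 0.14 -> ~14%
effective_test = test_outer                  # 0.30 -> 30%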


688-688: Type mismatch: train_model expects device: str | None but receives torch.device.

Line 688 passes a torch.device object where a string is expected. While this works due to duck typing, it's technically incorrect. Pass a string instead:

-            device=device,
+            device=str(device),

695-700: Verbose evaluation always uses classification metrics, even for regression tasks.

When task == "regression" (Hellinger distance), the verbose block still calls evaluate_classification_model, producing incorrect metrics. Branch on task:

         if verbose:
             test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-            avg_loss_test, dict_results, _ = evaluate_classification_model(
-                model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-            )
+            if task == "regression":
+                avg_loss_test, dict_results, _ = evaluate_regression_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
+            else:
+                avg_loss_test, dict_results, _ = evaluate_classification_model(
+                    model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+                )
             print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")

860-862: Use standard PyTorch Geometric attribute num_nodes instead of num_gates.

Data objects expect num_nodes as the standard attribute name. Using num_gates may cause issues with some GNN operations that rely on the standard attribute.

-        feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates).to(
+        feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates).to(
             "cuda" if torch.cuda.is_available() else "cpu"
         )

866-873: Output tensor shape mismatch will cause assertion failure.

The GNN model returns a tensor of shape [1, num_devices] for a single Data object. len(outputs) returns 1 (batch dimension), not num_devices, so the assertion len(class_labels) != len(outputs) will always fail for multi-device cases.

Squeeze the batch dimension before comparison:

         with torch.no_grad():
             outputs = gnn_model(feature_vector)
+        outputs = outputs.squeeze(0)  # Remove batch dimension: [1, num_devices] -> [num_devices]
         assert class_labels is not None
         if len(class_labels) != len(outputs):
             msg = "outputs and class_labels must be same length"
             raise ValueError(msg)

         pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)

327-337: Device score ordering may diverge from self.devices order.

Line 329 uses sorted(scores.keys()) for device ordering. If this differs from the sorted order of self.devices (sorted by dev.description at lines 189-191), the y tensor indices won't match the model's expected output ordering. Consider using the canonical self.devices ordering:

             if self.gnn:
                 x, _y, edge_idx, n_nodes, target_label = training_sample
-                name_device = sorted(scores.keys())
-                value_device = [scores[i] for i in name_device]
+                # Use canonical device ordering from self.devices
+                value_device = [scores.get(dev.description, -1.0) for dev in self.devices]

421-422: Remove ambiguous commented code.

The commented alternative # list(scores.values()) suggests uncertainty about the intended value. Since scores is used directly as a dict, remove the comment.

-        scores_list = scores  # list(scores.values())
+        scores_list = scores

628-632: mlp_units list is not stored in JSON metadata; inference will fail.

Lines 841-842 in predict_device_for_figure_of_merit expect to parse json_dict["mlp"] to reconstruct mlp_units, which works. However, json_dict["num_outputs"] is stored but the parsed mlp_units list itself is not. This is fine as long as the parsing logic matches exactly. Consider adding mlp_units explicitly for robustness:

         json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+        json_dict["mlp_units"] = mlp_units  # Store parsed list for clarity

503-503: Parameter num_outputs is immediately overwritten, rendering it unused.

Line 503 overwrites num_outputs = max(1, len(self.devices)) regardless of the value passed as an argument. Either remove num_outputs from the function signature and docstring, or honor the argument when provided (e.g., use a sentinel value like None).

-        num_outputs: int,
+        num_outputs: int | None = None,
...
-        num_outputs = max(1, len(self.devices))
+        if num_outputs is None:
+            num_outputs = max(1, len(self.devices))

506-506: KFold(n_splits=1) is invalid and will raise an error for very small datasets.

When len(training_data.y_train) == 1, k_folds becomes 1, and KFold(n_splits=1) raises a ValueError. Guard against this:

-        k_folds = min(len(training_data.y_train), 5)
+        k_folds = max(2, min(len(training_data.y_train), 5))

Alternatively, skip hyperparameter search entirely for datasets with fewer than 2 samples.

tests/hellinger_distance/test_estimated_hellinger_distance.py (3)

293-324: Cleanup logic as a test is fragile and order-dependent.

test_remove_files acts as teardown but depends on test execution order. If tests run in parallel or a subset is selected via -k, cleanup may not run or may run prematurely.

Consider using a module-scoped autouse fixture with yield:

@pytest.fixture(scope="module", autouse=True)
def cleanup_test_artifacts(source_path: Path, target_path: Path):
    """Cleanup test artifacts after all tests in module complete."""
    yield  # Let tests run
    # Cleanup logic here...
    if source_path.exists():
        for file in source_path.iterdir():
            if file.suffix == ".qasm":
                file.unlink()
        source_path.rmdir()
    # ... rest of cleanup

307-317: Duplicate data_path computation can be consolidated.

Lines 307-311 and 313-317 both compute data_path and iterate over it. Consolidate into a single loop:

     data_path = get_path_training_data() / "training_data_aggregated"
     if data_path.exists():
         for file in data_path.iterdir():
-            if file.suffix == ".npy":
-                file.unlink()
-
-    data_path = get_path_training_data() / "training_data_aggregated"
-    if data_path.exists():
-        for file in data_path.iterdir():
-            if file.suffix == ".pt":
+            if file.suffix in (".npy", ".pt"):
                 file.unlink()

216-220: Docstring is outdated and doesn't reflect parametrized test behavior.

The docstring mentions "the Hellinger distance model that was trained in the previous test", but this test now parametrizes over RF vs GNN and trains internally. Update to reflect current behavior:

-    """Test the entire predictor toolchain with the Hellinger distance model that was trained in the previous test."""
+    """Test the full predictor toolchain for `estimated_hellinger_distance` for both RF and GNN backends."""
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 77cca31 and 1acc0cb.

📒 Files selected for processing (2)
  • src/mqt/predictor/ml/predictor.py (12 hunks)
  • tests/hellinger_distance/test_estimated_hellinger_distance.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/mqt/predictor/ml/predictor.py (2)
src/mqt/predictor/ml/gnn.py (1)
  • GNN (191-295)
src/mqt/predictor/ml/helper.py (7)
  • TrainingData (575-585)
  • create_dag (192-301)
  • create_feature_vector (164-189)
  • get_openqasm3_gates (119-151)
  • get_path_trained_model (48-50)
  • get_path_trained_model_gnn (53-57)
  • get_path_training_data (36-38)
tests/hellinger_distance/test_estimated_hellinger_distance.py (1)
src/mqt/predictor/ml/helper.py (3)
  • TrainingData (575-585)
  • create_dag (192-301)
  • get_path_training_data (36-38)
🪛 Ruff (0.14.6)
src/mqt/predictor/ml/predictor.py

107-107: Boolean-typed positional argument in function definition

(FBT001)


107-107: Boolean default positional argument in function definition

(FBT002)


146-146: Logging statement uses f-string

(G004)


149-149: Logging statement uses f-string

(G004)


173-173: Boolean-typed positional argument in function definition

(FBT001)


173-173: Boolean default positional argument in function definition

(FBT002)


448-448: Boolean-typed positional argument in function definition

(FBT001)


448-448: Boolean default positional argument in function definition

(FBT002)


570-570: Boolean-typed positional argument in function definition

(FBT001)


570-570: Boolean default positional argument in function definition

(FBT002)


801-801: Boolean-typed positional argument in function definition

(FBT001)


801-801: Boolean default positional argument in function definition

(FBT002)

tests/hellinger_distance/test_estimated_hellinger_distance.py

158-158: Boolean-typed positional argument in function definition

(FBT001)


218-218: Boolean-typed positional argument in function definition

(FBT001)

🔇 Additional comments (1)
src/mqt/predictor/ml/predictor.py (1)

83-86: Type alias definitions look good.

The type aliases GNNSample, FeatureSample, and TrainingSample provide clear documentation of the expected tuple structures for both training paths.

            conv_activation=torch.nn.functional.leaky_relu,
            mlp_activation=torch.nn.functional.leaky_relu,
        ).to("cuda" if torch.cuda.is_available() else "cpu")
        gnn_model.load_state_dict(torch.load(path))

⚠️ Potential issue | 🟠 Major

Missing weights_only=True for torch.load poses a security risk.

torch.load(path) without weights_only=True can execute arbitrary code if the model file is malicious. Since you're only loading state dicts:

-        gnn_model.load_state_dict(torch.load(path))
+        gnn_model.load_state_dict(torch.load(path, weights_only=True))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        gnn_model.load_state_dict(torch.load(path))
+        gnn_model.load_state_dict(torch.load(path, weights_only=True))
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around line 858, replace the insecure
torch.load(path) call with torch.load(path, weights_only=True) so only tensor
weights are loaded; if your runtime device matters also pass map_location (e.g.,
map_location=device) and then feed the resulting state dict into
gnn_model.load_state_dict; optionally wrap the load in a try/except to surface a
clear error message on failure.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (5)
tests/hellinger_distance/test_estimated_hellinger_distance.py (5)

186-192: num_nodes should be an integer, not a tensor.

PyTorch Geometric's Data class expects num_nodes to be a plain Python int. Wrapping it in torch.tensor([n_nodes]) creates a 1D tensor, which can cause unexpected behavior during batching or graph operations.

             gnn_training_sample = Data(
                 x=x,
                 y=torch.tensor(labels_list[i], dtype=torch.float32),
                 edge_index=edge_idx,
-                num_nodes=torch.tensor([n_nodes]),
+                num_nodes=n_nodes,
             )

219-221: Docstring is outdated after parameterization.

The docstring still references "the Hellinger distance model that was trained in the previous test," but this test now trains its own model and is parameterized for both RF and GNN backends. Update to reflect the current behavior.

 def test_train_and_qcompile_with_hellinger_model(
     source_path: Path, target_path: Path, device: Target, gnn: bool
 ) -> None:
-    """Test the entire predictor toolchain with the Hellinger distance model that was trained in the previous test."""
+    """Test the entire predictor toolchain for estimated_hellinger_distance with both RF and GNN backends."""

307-317: Consolidate duplicate data_path computation.

data_path is computed twice with the same value. Merge the cleanup loops to avoid redundant computation.

     data_path = get_path_training_data() / "training_data_aggregated"
     if data_path.exists():
         for file in data_path.iterdir():
-            if file.suffix == ".npy":
+            if file.suffix in (".npy", ".pt"):
                 file.unlink()
-
-    data_path = get_path_training_data() / "training_data_aggregated"
-    if data_path.exists():
-        for file in data_path.iterdir():
-            if file.suffix == ".pt":
-                file.unlink()

319-323: Use tuple membership instead of chained or comparisons.

For readability and consistency, use in with a tuple.

-            if file.suffix == ".joblib" or file.suffix == ".pth" or file.suffix == ".json":
+            if file.suffix in (".joblib", ".pth", ".json"):

293-324: Consider converting cleanup test to a fixture.

test_remove_files acts as teardown for artifacts created by other tests. This couples correctness to test ordering and can break under parallel execution or when running a subset of tests with -k. Consider moving this logic to a module-scoped autouse fixture with yield to guarantee cleanup regardless of test selection or execution order.

Example fixture approach:

@pytest.fixture(scope="module", autouse=True)
def cleanup_test_artifacts(source_path: Path, target_path: Path):
    """Clean up test artifacts after all tests in the module complete."""
    yield
    # Cleanup logic here...
📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 64a94ff and c4cc835.

📒 Files selected for processing (1)
  • tests/hellinger_distance/test_estimated_hellinger_distance.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/hellinger_distance/test_estimated_hellinger_distance.py (3)
src/mqt/predictor/hellinger/utils.py (2)
  • calc_device_specific_features (37-132)
  • hellinger_distance (29-34)
src/mqt/predictor/ml/helper.py (3)
  • TrainingData (575-585)
  • create_dag (192-301)
  • get_path_training_data (36-38)
src/mqt/predictor/rl/helper.py (1)
  • get_path_training_data (97-99)
🪛 Ruff (0.14.6)
tests/hellinger_distance/test_estimated_hellinger_distance.py

158-158: Boolean-typed positional argument in function definition

(FBT001)


218-218: Boolean-typed positional argument in function definition

(FBT001)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (2)
tests/hellinger_distance/test_estimated_hellinger_distance.py (2)

21-31: LGTM!

The new imports for torch, torch_geometric.data (Batch, Data), and create_dag are appropriate for the GNN testing functionality.


254-265: Clarify the Windows + GNN timeout behavior.

The condition sys.platform == "win32" and not gnn means GNN tests on Windows skip the timeout warning path. If this is intentional (e.g., GNN compilation doesn't use the same timeout mechanism), a brief comment would help future maintainers understand this distinction.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: ASSERTIVE

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c4cc835 and 6ebded4.

📒 Files selected for processing (2)
  • tests/device_selection/test_predictor_ml.py (4 hunks)
  • tests/hellinger_distance/test_estimated_hellinger_distance.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/device_selection/test_predictor_ml.py (2)
src/mqt/predictor/ml/helper.py (1)
  • get_path_training_data (36-38)
src/mqt/predictor/ml/predictor.py (2)
  • predict_device_for_figure_of_merit (800-880)
  • Predictor (166-797)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py

42-42: Boolean-typed positional argument in function definition

(FBT001)


42-42: Boolean-typed positional argument in function definition

(FBT001)


115-115: Boolean-typed positional argument in function definition

(FBT001)

tests/hellinger_distance/test_estimated_hellinger_distance.py

158-158: Boolean-typed positional argument in function definition

(FBT001)


219-219: Boolean-typed positional argument in function definition

(FBT001)

⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
  • GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
  • GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
🔇 Additional comments (6)
tests/device_selection/test_predictor_ml.py (3)

38-43: Parameterization covers key scenarios appropriately.

The test matrix exercises RF-only, GNN without verbose, and GNN with verbose modes. This provides good coverage for the new GNN path integration.


68-75: Artifact assertions correctly differentiate GNN and RF paths.

The conditional logic now properly validates:

  • GNN path: graph_dataset_expected_fidelity.pt plus the shared .npy files
  • RF path: training_data_expected_fidelity.npy plus the shared .npy files

This addresses the earlier feedback about asserting names_list and scores_list existence for the GNN path.


114-117: Parameterized test correctly exercises both predictor backends.

The GNN flag is properly passed to Predictor, ensuring both RF and GNN backends are tested for the _get_prepared_training_data error path.

tests/hellinger_distance/test_estimated_hellinger_distance.py (3)

166-193: GNN training data construction is correct.

The branching logic properly:

  • Uses calc_device_specific_features for RF and create_dag for GNN
  • Constructs appropriate TrainingData objects for each path
  • Correctly passes n_nodes as an integer to Data (addressing earlier feedback)

196-214: Model training and evaluation paths are properly separated.

The conditional training logic correctly dispatches to train_gnn_model vs train_random_forest_model, and the evaluation uses appropriate methods for each model type. The tolerance comment at line 213 explains the rationale for atol=2e-1.


255-266: Clarify whether Windows timeout warning should apply to GNN path.

The condition sys.platform == "win32" and not gnn means the timeout warning test is skipped when gnn=True on Windows. Is this intentional because the GNN path handles timeouts differently, or should both paths warn on Windows?

If GNN compilation also uses timeouts that aren't supported on Windows, consider:

-        if sys.platform == "win32" and not gnn:
+        if sys.platform == "win32":

Otherwise, add a comment explaining why GNN is exempt from this warning.

Comment on lines 157 to 159
@pytest.mark.parametrize("gnn", [False, True], ids=["rf", "gnn"])
def test_train_model_and_predict(device: Target, gnn: bool) -> None:
    """Test the training of the random forest regressor. The trained model is saved and used in the following tests."""

🧹 Nitpick | 🔵 Trivial

Docstring no longer accurately reflects parameterized test behavior.

The docstring still says "Test the training of the random forest regressor" but the test is now parameterized to cover both RF and GNN backends.

 @pytest.mark.parametrize("gnn", [False, True], ids=["rf", "gnn"])
 def test_train_model_and_predict(device: Target, gnn: bool) -> None:
-    """Test the training of the random forest regressor. The trained model is saved and used in the following tests."""
+    """Test model training and prediction for both RF and GNN backends."""
🧰 Tools
🪛 Ruff (0.14.6)

158-158: Boolean-typed positional argument in function definition

(FBT001)

🤖 Prompt for AI Agents
In tests/hellinger_distance/test_estimated_hellinger_distance.py around lines
157 to 159, the test docstring incorrectly states it only tests training of the
random forest regressor while the test is parameterized to run for both RF and
GNN; update the docstring to accurately describe that the test trains and saves
a model for both backends (random forest and GNN) and that the saved model is
used in subsequent tests, and briefly document the gnn parameter meaning (False
=> RF, True => GNN).

Comment on lines +294 to +318
def test_remove_files(source_path: Path, target_path: Path) -> None:
    """Remove files created during testing."""
    if source_path.exists():
        for file in source_path.iterdir():
            if file.suffix == ".qasm":
                file.unlink()
        source_path.rmdir()

    if target_path.exists():
        for file in target_path.iterdir():
            if file.suffix == ".qasm":
                file.unlink()
        target_path.rmdir()

    data_path = get_path_training_data() / "training_data_aggregated"
    if data_path.exists():
        for file in data_path.iterdir():
            if file.suffix in (".npy", ".pt"):
                file.unlink()

    model_path = get_path_training_data() / "trained_model"
    if model_path.exists():
        for file in model_path.iterdir():
            if file.suffix in (".joblib", ".pth", ".json"):
                file.unlink()

🧹 Nitpick | 🔵 Trivial

Consider converting cleanup to a fixture for test isolation.

test_remove_files acts as teardown but depends on test execution order. If tests are run in isolation or in parallel, cleanup may not occur. This was flagged in a previous review.

Convert to a module-scoped (or session-scoped) autouse fixture:

@pytest.fixture(scope="module", autouse=True)
def cleanup_test_artifacts(source_path: Path, target_path: Path):
    """Clean up test artifacts after all tests in the module complete."""
    yield  # Let tests run
    # Cleanup logic here...

This ensures cleanup runs regardless of test selection or ordering.
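
For concreteness, a fuller version of that fixture, reusing the removal logic quoted from test_remove_files above; the fixture name and scope follow the suggestion and are not existing code, and it assumes source_path and target_path are module-scoped fixtures:

from collections.abc import Iterator
from pathlib import Path

import pytest

from mqt.predictor.ml.helper import get_path_training_data


@pytest.fixture(scope="module", autouse=True)
def cleanup_test_artifacts(source_path: Path, target_path: Path) -> Iterator[None]:
    """Clean up test artifacts after all tests in the module complete."""
    yield  # let the tests run first
    for directory, suffixes in (
        (source_path, (".qasm",)),
        (target_path, (".qasm",)),
        (get_path_training_data() / "training_data_aggregated", (".npy", ".pt")),
        (get_path_training_data() / "trained_model", (".joblib", ".pth", ".json")),
    ):
        if directory.exists():
            for file in directory.iterdir():
                if file.suffix in suffixes:
                    file.unlink()
    for directory in (source_path, target_path):
        if directory.exists():
            directory.rmdir()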

🤖 Prompt for AI Agents
In tests/hellinger_distance/test_estimated_hellinger_distance.py around lines
294 to 318, replace the standalone test_remove_files teardown function with a
pytest fixture that runs automatically for the module (or session) and performs
cleanup after tests by yielding control and then executing the same removal
logic; make the fixture use scope="module" (or "session") and autouse=True,
accept source_path and target_path as parameters, ensure the post-yield cleanup
is idempotent (check existence before unlink/rmdir or use safe removal) and
covers data_path and model_path cleanup so cleanup runs regardless of test
ordering or isolation.
