Add GNN-Based Predictor with DAG Preprocessing #430
base: main
Conversation
…into gnn-branch
…into gnn-branch
…into gnn-branch
Actionable comments posted: 3
♻️ Duplicate comments (4)
src/mqt/predictor/ml/predictor.py (4)
503-506: `num_outputs` parameter is unused; `k_folds` needs a guard for tiny datasets. Two issues persist from the prior review:

Line 503: The `num_outputs` parameter (line 442) is immediately overwritten by `max(1, len(self.devices))`, making the parameter useless. Either honor the passed value or remove it from the signature.

Line 506: `KFold(n_splits=k_folds)` will raise `ValueError` if `k_folds < 2`. Since `train_gnn_model` sets `k_folds = min(len(training_data.y_train), 5)`, a dataset with 1 sample causes a failure.

Apply this diff:

```diff
-    num_outputs = max(1, len(self.devices))
+    # num_outputs already passed from caller; use it directly
     # Split into k-folds
+    k_folds = max(2, k_folds)  # KFold requires at least 2 splits
     kf = KFold(n_splits=k_folds, shuffle=True)
```

And either remove `num_outputs` from the signature or use `num_outputs = num_outputs or max(1, len(self.devices))`.
604-610: Dead code: device descriptions computed but not stored. Line 607 computes `[dev.description for dev in self.devices]` but the result is discarded. This list is needed for inference to map model outputs back to device names.

Store the device labels in the JSON metadata:

```diff
 if num_outputs == 1:
     # task = "binary"
     task = "classification"
-    [dev.description for dev in self.devices]
+else:
+    task = "classification"
+device_classes = [dev.description for dev in self.devices]
 sampler_obj = TPESampler(n_startup_trials=10)
```

Then add to `json_dict` before saving:

```diff
 json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+json_dict["classes"] = device_classes if self.figure_of_merit != "hellinger_distance" else None
```
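As a sanity check of this metadata round trip, here is a minimal, self-contained sketch; the file name and device labels are purely illustrative, not taken from the PR:

```python
import json
from pathlib import Path

# Hypothetical metadata file and device labels, for illustration only.
meta_path = Path("gnn_meta.json")
device_classes = ["device_a", "device_b", "device_c"]

# Training side: persist the labels next to the hyperparameters.
json_dict = {"num_outputs": len(device_classes), "classes": device_classes}
meta_path.write_text(json.dumps(json_dict))

# Inference side: read them back to map model outputs to device names.
class_labels = json.loads(meta_path.read_text())["classes"]
print(class_labels)  # ['device_a', 'device_b', 'device_c']
```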
696-701: Verbose evaluation ignores task; always uses classification metrics for regression. When `self.figure_of_merit == "estimated_hellinger_distance"` or `"hellinger_distance"`, `task = "regression"` (line 602), but the verbose block unconditionally calls `evaluate_classification_model`. This produces meaningless metrics for regression tasks.

Branch on the task variable:

```diff
 if verbose:
     test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-    avg_loss_test, dict_results, _ = evaluate_classification_model(
-        model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-    )
+    if task == "regression":
+        avg_loss_test, dict_results, _ = evaluate_regression_model(
+            model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+        )
+    else:
+        avg_loss_test, dict_results, _ = evaluate_classification_model(
+            model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+        )
     print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
```
856-869: Critical: `gnn_model.classes` attribute does not exist; prediction will fail with `AttributeError`. Line 860 accesses `gnn_model.classes`, but the `GNN` class (from `gnn.py`) never defines a `classes` attribute. This will cause an `AttributeError` at runtime.

Additionally:
- Line 856: `torch.load(path)` should specify `map_location` for CPU/GPU portability and `weights_only=True` for security (state dicts are safe).
- Line 847: Will raise `KeyError` because `"mlp_units"` is not stored in JSON (only `"mlp"` as a string).

The class labels should be loaded from the JSON metadata (once stored during training):

```diff
-gnn_model.load_state_dict(torch.load(path))
+gnn_model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
 x, edge_index, number_of_gates = create_dag(qc)
 feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
 gnn_model.eval()
-class_labels = gnn_model.classes
+class_labels = json_dict.get("classes")  # Load from JSON instead of model attribute
 with torch.no_grad():
     outputs = gnn_model(feature_vector)
```

And fix the `mlp_units` loading:

```diff
-    mlp_units=json_dict["mlp_units"],
+    mlp_units=([] if json_dict["mlp"] == "none" else [int(x) for x in json_dict["mlp"].split(",")]),
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (1)
src/mqt/predictor/ml/predictor.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/mqt/predictor/ml/predictor.py (3)
src/mqt/predictor/ml/gnn.py (1)
GNN (191-295)
src/mqt/predictor/hellinger/utils.py (1)
get_hellinger_model_path (135-144)
src/mqt/predictor/ml/helper.py (4)
TrainingData (569-579)
create_dag (192-301)
create_feature_vector (164-189)
get_openqasm3_gates (119-151)
🪛 Ruff (0.14.5)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
571-571: Boolean-typed positional argument in function definition
(FBT001)
571-571: Boolean default positional argument in function definition
(FBT002)
802-802: Boolean-typed positional argument in function definition
(FBT001)
802-802: Boolean default positional argument in function definition
(FBT002)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (5)
src/mqt/predictor/ml/predictor.py (5)
83-98: Type definitions are well-structured. The type aliases (`GNNSample`, `FeatureSample`, `TrainingSample`) and the `TrainGNNKwargs` TypedDict provide good type safety for the GNN pathway.
100-163: Clean conditional training logic. The implementation correctly branches between RF and GNN training based on the `gnn` flag, and the `**gnn_kwargs` pattern cleanly forwards GNN-specific parameters. The boolean parameter is acceptable here for API simplicity. (A small illustrative sketch of this pattern follows below.)
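To make the alias-and-forward pattern concrete, here is a minimal, self-contained sketch; the field names and helper functions are illustrative assumptions, not the PR's actual definitions (Python 3.12+ for `Unpack` with `**kwargs`):

```python
from typing import TypedDict, Unpack


class TrainGNNKwargs(TypedDict, total=False):
    """Illustrative GNN-only training options; the real TypedDict may differ."""

    num_epochs: int
    verbose: bool


def train_gnn_model(**kwargs: Unpack[TrainGNNKwargs]) -> str:
    return f"GNN trained with {dict(kwargs)}"  # stand-in for the real training routine


def train_random_forest_model() -> str:
    return "RF trained"  # stand-in for the classical path


def setup_device_predictor(*, gnn: bool = False, **gnn_kwargs: Unpack[TrainGNNKwargs]) -> str:
    # Branch once on the flag; GNN-specific options flow through untouched.
    return train_gnn_model(**gnn_kwargs) if gnn else train_random_forest_model()


print(setup_device_predictor(gnn=True, num_epochs=5, verbose=True))
```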
326-341: Data construction for GNN samples is reasonable, but note the y-tensor semantics. The code discards the `_y` from `training_sample` (which was the target device index) and replaces it with `value_device` (all device scores). This makes sense for multi-output regression but differs from the classification label. Ensure downstream training/evaluation code expects this multi-output format. (A toy example of the two target shapes follows below.)
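For illustration, a standalone toy example of the two target encodings discussed above; device names and scores are made up:

```python
import torch

# Hypothetical per-device scores for one circuit.
scores = {"device_a": 0.91, "device_b": 0.84, "device_c": 0.66}
device_order = list(scores)

# Classification-style target: index of the best device.
best_idx = max(range(len(device_order)), key=lambda i: scores[device_order[i]])
y_class = torch.tensor([best_idx], dtype=torch.float)

# Multi-output regression target: every device's score in a fixed order.
y_scores = torch.tensor([[scores[d] for d in device_order]], dtype=torch.float)

print(y_class)   # tensor([0.])
print(y_scores)  # tensor([[0.9100, 0.8400, 0.6600]])
```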
421-434: Return type and GNN sample creation are correct. The GNN branch properly creates the tuple with node features, target index, edge index, gate count, and target label. The `scores_list` now returns the dict directly (despite the name), which is used correctly in `generate_training_data`.
769-771: `weights_only=False` is required, but be aware of the security implications. Using `torch.load(..., weights_only=False)` is necessary to deserialize `torch_geometric.data.Data` objects, but it allows arbitrary code execution if the `.pt` file is malicious. Ensure training data files are from trusted sources only. (A short sketch of the trade-off follows below.)
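A minimal sketch of that trade-off, assuming a hypothetical file path produced by the pipeline itself:

```python
from pathlib import Path

import torch

# Hypothetical training-data file written earlier by this pipeline.
data_path = Path("training_data") / "training_data_graph.pt"

if data_path.exists():
    # Full unpickling is needed for torch_geometric Data objects, but it can run
    # arbitrary code - only do this for files your own pipeline generated.
    data_list = torch.load(data_path, weights_only=False)

# For plain state dicts (no custom classes), prefer the restricted loader:
# state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
```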
```python
dict_best_hyper = study.best_trial.params  # user_attrs.get("best_hparams")
# Build model (ensure final layer outputs raw logits/no activation)
json_dict = study.best_trial.params
mlp_str = dict_best_hyper["mlp"]
mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]

json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
```
mlp_units not stored in JSON metadata; prediction will fail.
Line 847 in predict_device_for_figure_of_merit reads json_dict["mlp_units"], but json_dict is populated from study.best_trial.params which stores "mlp" as a string (e.g., "128,64"), not "mlp_units" as a list.
Add the parsed mlp_units to the JSON before saving:
```diff
 mlp_str = dict_best_hyper["mlp"]
 mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]
 json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
dict_best_hyper = study.best_trial.params  # user_attrs.get("best_hparams")
# Build model (ensure final layer outputs raw logits/no activation)
json_dict = study.best_trial.params
mlp_str = dict_best_hyper["mlp"]
mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]
json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
```
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 627 to 633, json_dict is
populated from study.best_trial.params which contains "mlp" as a comma string,
but later code expects json_dict["mlp_units"]; add the parsed mlp_units list
into json_dict before any saving/serialization. Specifically, after computing
mlp_units (from mlp_str), set json_dict["mlp_units"] = mlp_units (or [] when
"none") and ensure num_outputs remains set as before so downstream code can read
mlp_units as a list.
```python
outputs = gnn_model(feature_vector)
assert class_labels is not None
if len(class_labels) != len(outputs):
    msg = "outputs and class_labels must be same length"
    raise ValueError(msg)

pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
sorted_devices = np.array([label for _, label in pairs])
```
I'll help you verify this review comment about the GNN model batching issue. Let me start by examining the actual code to understand the context.
<function_calls>
cat -n src/mqt/predictor/ml/predictor.py | sed -n '850,880p'
</function_calls>
cat -n src/mqt/predictor/ml/predictor.py | sed -n '1,50p'
</function_calls>
wc -l src/mqt/predictor/ml/predictor.py
</function_calls>
Let me check the file size and read the relevant sections:
<function_calls>
head -50 src/mqt/predictor/ml/predictor.py | cat -n
</function_calls>
Now let me read the specific lines mentioned in the review:
<function_calls>
sed -n '850,880p' src/mqt/predictor/ml/predictor.py | cat -n
</function_calls>
Let me also check the imports and the broader context around the function:
<function_calls>
sed -n '800,870p' src/mqt/predictor/ml/predictor.py | cat -n
</function_calls>
Let me search for the function containing these lines:
<function_calls>
sed -n '830,870p' src/mqt/predictor/ml/predictor.py | cat -n
</function_calls>
Let me try a different approach to find the function context:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print lines 850-870 with context
for i in range(849, min(870, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me also look at the imports and find the create_dag function:
<function_calls>
rg -n "def create_dag" src/mqt/predictor/ml/predictor.py
</function_calls>
Let me check what comes before line 862:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print lines 840-875 with context
for i in range(839, min(875, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me also search for gnn_model and DataLoader usage:
<function_calls>
rg -n "gnn_model" src/mqt/predictor/ml/predictor.py | head -20
</function_calls>
Let me check the imports and broader context:
<function_calls>
rg -n "from torch_geometric" src/mqt/predictor/ml/predictor.py
</function_calls>
Let me look at the function definition that contains line 862:
<function_calls>
rg -n "def " src/mqt/predictor/ml/predictor.py | grep -E "^[0-9]+:" | awk -F: '{print $1}' | while read line; do
if [ "$line" -lt 862 ]; then
last_func=$line
fi
done
echo "Checking around line 862..."
python3 << 'PYEOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find the function containing line 862
for i in range(861, -1, -1):
if lines[i].strip().startswith('def '):
print(f"Function at line {i+1}: {lines[i].strip()}")
break
PYEOF
</function_calls>
Let me use a more direct approach:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find function containing line 862 (index 861)
for i in range(861, -1, -1):
if lines[i].strip() and lines[i][0] not in (' ', '\t'):
if lines[i].strip().startswith('def '):
print(f"Function at line {i+1}: {lines[i].strip()}\n")
# Print from function def to line 880
for j in range(i, min(880, len(lines))):
print(f"{j+1:4d}: {lines[j]}", end='')
break
EOF
</function_calls>
Now let me check the training code mentioned at line 335:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print around line 335
for i in range(max(0, 325), min(350, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me also look for the Data class usage and how feature_vector is being used:
<function_calls>
rg -n "Data(" src/mqt/predictor/ml/predictor.py -B 2 -A 2
</function_calls>
Let me check the model's eval() method and how it's called:
<function_calls>
rg -n ".eval()" src/mqt/predictor/ml/predictor.py -B 3 -A 3
</function_calls>
Let me check the exact lines 862-869 more carefully:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print the exact range mentioned in the review
print("Lines 862-869:")
for i in range(861, 869):
if i < len(lines):
print(f"{i+1:4d}: {lines[i]}", end='')
print("\n\nContext (lines 850-875):")
for i in range(849, min(875, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me now look at model creation and training to understand how it's built:
<function_calls>
rg -n "GNN|GraphNeural|gnn_model\s*=" src/mqt/predictor/ml/predictor.py -B 2 -A 5
</function_calls>
Let me check if there's a GNN architecture definition:
<function_calls>
find src -name "*.py" | xargs rg -l "class.*GNN|GraphNeural" 2>/dev/null
</function_calls>
Let me check what gnn_model is loaded from:
<function_calls>
rg -n "gnn_model\s*=" src/mqt/predictor/ml/predictor.py -B 5 -A 2
</function_calls>
Let me search more broadly for how the model is instantiated or loaded:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
content = f.read()
# Find all lines with gnn_model
import re
matches = re.finditer(r'.*gnn_model.*', content)
for match in matches:
# Find line number
line_num = content[:match.start()].count('\n') + 1
print(f"Line {line_num}: {match.group()}")
EOF
</function_calls>
Let me look at the function definition more carefully:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find function definition containing line 862
current_func = None
for i in range(len(lines)):
if lines[i][0] not in (' ', '\t') and lines[i].strip().startswith('def '):
current_func = (i+1, lines[i].strip())
if i == 861: # Line 862 is index 861
if current_func:
print(f"Line 862 is in function starting at line {current_func[0]}: {current_func[1]}\n")
# Print the function starting from its definition
start = current_func[0] - 1
for j in range(start, min(start + 60, len(lines))):
print(f"{j+1:4d}: {lines[j]}", end='')
break
EOF
</function_calls>
Perfect! Now let me understand the actual behavior by looking at what happens with feature_vector:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print broader context around lines 862-869
print("Extended context (lines 840-890):")
for i in range(839, min(890, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Now let me check the imports to see if Data and Batch are already imported:
<function_calls>
head -50 src/mqt/predictor/ml/predictor.py | cat -n
</function_calls>
Let me see all imports:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
print("Imports section:")
for i in range(len(lines)):
if lines[i].strip() and not (lines[i].startswith('import ') or lines[i].startswith('from ')):
if i > 1: # Stop after imports
break
if lines[i].strip():
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me search for the full imports more carefully:
<function_calls>
rg -n "^import |^from " src/mqt/predictor/ml/predictor.py | head -30
</function_calls>
Now let me verify how the model uses the input by checking if there are tests or other usages:
<function_calls>
rg -n "forward|call" src/mqt/predictor/ml/predictor.py -B 2 -A 5
</function_calls>
Let me check the entire function containing lines 862-869:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find the function definition
for i in range(len(lines)):
if i >= 780 and i <= 870:
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me search for the actual function that contains these lines:
<function_calls>
rg -n "def " src/mqt/predictor/ml/predictor.py | tail -20
</function_calls>
Let me be more systematic:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
total = len(lines)
print(f"Total lines: {total}")
# Find all function definitions and their line numbers
funcs = []
for i in range(total):
if lines[i][0] not in (' ', '\t') and 'def ' in lines[i]:
funcs.append((i+1, lines[i].strip()))
print("\nFunctions in file:")
for line_num, func_def in funcs[-15:]: # Last 15 functions
print(f" Line {line_num}: {func_def[:60]}")
# Find function containing line 862
target = 862
for i, (line_num, func_def) in enumerate(funcs):
next_line = funcs[i+1][0] if i+1 < len(funcs) else total
if line_num <= target < next_line:
print(f"\nLine {target} is in function: {func_def}")
print(f"Function spans lines {line_num}-{next_line-1}\n")
# Print the function
for j in range(line_num-1, min(line_num+80, next_line)):
print(f"{j+1:4d}: {lines[j]}", end='')
break
EOF
</function_calls>
Excellent! Now I can see the actual context. Let me verify the Data import and check how it's used in training:
<function_calls>
rg -n "from torch_geometric.data import|class Data" src/mqt/predictor/ml/predictor.py
</function_calls>
Now let me check if there's a training function to understand how the model was trained:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print lines around 335 mentioned in review
print("Context around line 335:")
for i in range(max(0, 325), min(360, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me check if the gnn_model actually supports single Data objects by searching for its usage patterns:
<function_calls>
rg -n "gnn_model(" src/mqt/predictor/ml/predictor.py -B 3 -A 3
</function_calls>
Now let me search for the actual GNN model definition in the repository:
<function_calls>
find src -name "*.py" -exec rg -l "class.*NN|class.*GNN" {} ;
</function_calls>
Let me check what gnn_model is (where it comes from):
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Search for gnn_model initialization or loading
for i in range(len(lines)):
if 'gnn_model' in lines[i]:
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Perfect! Now I can see the actual context. Let me verify if there are issues by checking how the model is called in training vs prediction. Let me look at line 335 area more carefully:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find function containing line 335
for i in range(len(lines)):
if i >= 320 and i <= 370:
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Now let me search for DataLoader usage to understand how the model is trained:
<function_calls>
rg -n "DataLoader" src/mqt/predictor/ml/predictor.py -B 3 -A 3
</function_calls>
Let me check what the create_dag function returns:
<function_calls>
rg -n "def create_dag" src/mqt/predictor/ml/predictor.py -A 20
</function_calls>
Let me check if there's documentation on the gnn_model or its architecture:
<function_calls>
rg -n "GAT|GCN|GraphConv|graph_model" src/mqt/predictor/ml/predictor.py -B 2 -A 2
</function_calls>
Let me search the entire repository for GNN model definitions:
<function_calls>
find src/mqt/predictor -name "*.py" | xargs rg -l "GAT|GCN|GraphConv"
</function_calls>
Let me check what files are in the ml directory:
<function_calls>
ls -la src/mqt/predictor/ml/
</function_calls>
Let me check the model file:
<function_calls>
wc -l src/mqt/predictor/ml/gnn_models.py
</function_calls>
<function_calls>
head -100 src/mqt/predictor/ml/gnn_models.py | cat -n
</function_calls>
Let me search for pooling layers and forward methods in the model file:
<function_calls>
rg -n "def forward|pool|Pool" src/mqt/predictor/ml/gnn_models.py -B 2 -A 5
</function_calls>
Let me look at the full GNN model class:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/gnn_models.py', 'r') as f:
lines = f.readlines()
# Print first 150 lines to see model structure
for i in range(min(150, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Now let me check the forward method more carefully:
<function_calls>
rg -n "def forward" src/mqt/predictor/ml/gnn_models.py -A 30
</function_calls>
Perfect! Now I understand the model. Let me verify if the issue about batching is valid. Let me check how the training code uses Data objects:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find how Data is used in training (around line 335)
for i in range(max(0, 320), min(380, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me search for the training function that creates Data objects:
<function_calls>
rg -n "def .*train|Data(" src/mqt/predictor/ml/predictor.py | head -20
</function_calls>
Let me examine how Data objects are created in training:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find Data( usage
for i in range(len(lines)):
if 'Data(' in lines[i]:
# Print context
start = max(0, i-2)
end = min(len(lines), i+3)
for j in range(start, end):
marker = ">>> " if j == i else " "
print(f"{marker}{j+1:4d}: {lines[j]}", end='')
print()
EOF
</function_calls>
Now let me verify how the model inference actually works and whether batching is necessary. Let me check if the model supports single Data objects:
<function_calls>
python3 << 'EOF'
# Let's trace through the model's forward method
with open('src/mqt/predictor/ml/gnn_models.py', 'r') as f:
lines = f.readlines()
# Find the forward method
in_forward = False
for i in range(len(lines)):
if 'def forward' in lines[i]:
in_forward = True
if in_forward:
print(f"{i+1:4d}: {lines[i]}", end='')
if i > 0 and lines[i][0] not in (' ', '\t') and i > 10:
if 'def ' in lines[i]:
break
if i > 100: # Safety limit
break
EOF
</function_calls>
Now let me check if global_add_pool can handle single graphs:
<function_calls>
rg -n "global_add_pool|GlobalAddPool" src/mqt/predictor/ml/gnn_models.py
</function_calls>
Let me trace through what happens when you call a PyTorch Geometric model with a single Data object vs a Batch:
<function_calls>
python3 << 'EOF'
# Check PyTorch Geometric documentation behavior
# A single Data object has batch=None by default
# When passed through GCN/GAT layers, it works fine
# But global_add_pool expects to work with batched graphs
# Let's verify this understanding
print("Understanding PyTorch Geometric behavior:")
print("- Single Data object has no 'batch' attribute (or batch=None)")
print("- GCN/GAT layers work fine with single graphs")
print("- global_add_pool expects batch tensor for multiple graphs")
print("- But for single graph, it should return output[0] equivalent")
print()
print("The key question: what does global_add_pool return for single Data?")
print("In torch_geometric.nn.glob, global_add_pool should sum all node features")
print("For single graph, this gives shape [1, out_features]")
EOF
</function_calls>
Now let me check the actual predict function to see what happens with outputs:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print lines 862-875 to see what happens with outputs
for i in range(861, min(880, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Now I need to understand what the actual output shape is. Let me look at what the model returns:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/gnn_models.py', 'r') as f:
lines = f.readlines()
# Find forward method return statement
in_forward = False
for i in range(len(lines)):
if 'def forward' in lines[i]:
in_forward = True
start = i
if in_forward and 'return' in lines[i]:
# Print the forward method and return
for j in range(start, i+1):
print(f"{j+1:4d}: {lines[j]}", end='')
break
EOF
</function_calls>
Great! Now I can see that the forward method returns self.lin(out) where out is the result of global_add_pool. For a single Data object:
- global_add_pool returns shape [1, hidden_dim]
- self.lin (Linear layer) returns shape [1, num_labels]
So outputs will have shape [1, num_labels]. Let me verify that this matches with class_labels:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Look at the predict function more carefully
# to see what class_labels is
for i in range(max(0, 810), min(900, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Now I see the issue! Let me verify:
- class_labels is defined as a parameter (probably a list of device labels)
- outputs from the model has shape [1, num_devices]
- The code does len(class_labels) != len(outputs) where:
  - len(class_labels) = number of devices (e.g., 10)
  - len(outputs) = 1 (the batch dimension)
This is indeed a shape mismatch! Let me verify what should happen:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Find function signature to understand class_labels
for i in range(max(0, 800), min(870, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
if i == 820:
break
print("\n\n=== Looking for function signature ===\n")
# Find the predict function def
for i in range(len(lines)):
if 'def predict' in lines[i]:
for j in range(i, min(i+15, len(lines))):
print(f"{j+1:4d}: {lines[j]}", end='')
break
EOF
</function_calls>
Perfect! Now I can see that class_labels is a parameter to the function. The issue is clear:
The model returns a tensor of shape [1, num_devices] (since it's a single Data object), but class_labels is a list of device labels. The comparison len(class_labels) != len(outputs) compares:
- len(class_labels) = number of elements in the labels list (e.g., 10)
- len(outputs) = 1 (batch dimension)
Let me also verify the num_nodes vs num_gates issue:
<function_calls>
rg -n "num_nodes|num_gates" src/mqt/predictor/ml/predictor.py
</function_calls>
Let me check line 335 to see if num_nodes is used there:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Look around line 335
for i in range(max(0, 325), min(350, len(lines))):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Let me search more broadly for how Data objects are created with these attributes:
<function_calls>
rg -n "Data(" src/mqt/predictor/ml/predictor.py -B 5 -A 2
</function_calls>
Now I have a complete picture. Let me verify one more thing - whether num_nodes or num_gates is the correct attribute for torch_geometric.data.Data:
<function_calls>
python3 << 'EOF'
# Check PyTorch Geometric Data documentation
# The Data class expects attributes like:
# - x: node features
# - edge_index: edges
# - num_nodes: automatically inferred or can be set manually
# The review mentions num_gates should be num_nodes
# Let's verify if num_gates is used anywhere
import subprocess
result = subprocess.run(['rg', '-n', 'num_gates', 'src/mqt/predictor/ml/'],
capture_output=True, text=True)
print("Occurrences of 'num_gates':")
print(result.stdout)
EOF
</function_calls>
Let me check if the Data class is even using num_gates:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
content = f.read()
# Count occurrences
import re
num_gates = len(re.findall(r'num_gates', content))
num_nodes = len(re.findall(r'num_nodes', content))
print(f"num_gates occurrences: {num_gates}")
print(f"num_nodes occurrences: {num_nodes}")
# Print the lines with these
for match in re.finditer(r'.*num_gates.*', content):
line_num = content[:match.start()].count('\n') + 1
print(f"Line {line_num}: {match.group()}")
EOF
</function_calls>
Now let me verify the actual issues by examining the code more thoroughly:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print the complete predict function
in_func = False
func_indent = 0
for i in range(len(lines)):
if 'def predict' in lines[i]:
in_func = True
func_indent = len(lines[i]) - len(lines[i].lstrip())
print(f"Found predict function at line {i+1}\n")
if in_func:
current_indent = len(lines[i]) - len(lines[i].lstrip()) if lines[i].strip() else float('inf')
# Check if we've left the function
if current_indent <= func_indent and lines[i].strip() and i > 805:
break
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Excellent! Now I have the complete picture. Let me analyze the issues:
1. Batching issue: The model expects outputs to match class_labels length. With a single Data object:
   - Model returns shape [1, num_devices]
   - class_labels has length = number of devices
   - len(outputs) = 1, len(class_labels) = num_devices - this will cause the assertion to fail!
2. num_gates vs num_nodes: The Data object is created with num_gates=number_of_gates, but PyTorch Geometric's Data class doesn't have a num_gates attribute. This is likely a bug.
Let me verify how to properly use torch_geometric.data.Data:
<function_calls>
python3 << 'EOF'
# Check if num_gates is used as a custom attribute or if it should be num_nodes
# In torch_geometric.data.Data, standard attributes are:
# - x: node features [num_nodes, num_node_features]
# - edge_index: [2, num_edges]
# - y: target
# - num_nodes: number of nodes (optional, auto-inferred)
# Custom attributes can be added, but num_nodes is the standard
print("PyTorch Geometric Data attributes:")
print("- x: node features")
print("- edge_index: edge indices")
print("- num_nodes: number of nodes (standard attribute)")
print("- num_gates: NOT a standard attribute")
print()
print("When using Data with models expecting num_nodes:")
print("- Either set num_nodes explicitly")
print("- Or let it be auto-inferred from x.size(0)")
EOF
</function_calls>
Now I have enough information to verify the review comment. Let me check one more time if Batch is actually needed or if the model works fine with single Data:
<function_calls>
python3 << 'EOF'
# The real issue is:
# 1. outputs.tolist() is called on line 867
#    This suggests outputs is a tensor
# 2. zip(outputs.tolist(), class_labels, strict=False)
#    This is trying to zip them together
# If outputs shape is [1, num_devices] and class_labels has num_devices elements:
# - outputs.tolist() would be a list of 1 element (which is itself a list)
# - class_labels is a list of num_devices elements
# - zip() would only iterate 1 time (takes shortest iterable)
# So the real issue is the shape mismatch, not necessarily about batching
# The fix should be to squeeze the batch dimension OR
# use outputs[0] to get the actual predictions
print("The actual issue:")
print("If outputs has shape [1, num_devices]:")
print("- outputs.tolist() = [[score1, score2, ...]]")
print("- class_labels = [label1, label2, ...]")
print("- len(outputs) = 1, len(class_labels) = num_devices")
print()
print("The assertion 'len(class_labels) != len(outputs)' will fail")
print()
print("The fix should use:")
print("- outputs[0] to squeeze the batch dimension")
print("- OR outputs.squeeze() to remove the batch dimension")
EOF
</function_calls>
Now let me verify my understanding of the actual code behavior by checking what outputs looks like and how it's used:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Print lines around the problematic code
for i in range(860, 875):
if i < len(lines):
print(f"{i+1:4d}: {lines[i]}", end='')
EOF
</function_calls>
Perfect! Now I have all the information I need. The review comment is partially correct but with some important nuances:

1. ✅ Shape mismatch is real: The assertion will indeed fail because outputs from the model has shape [1, num_devices], class_labels is a list of length num_devices, and len(outputs) returns 1 (the batch dimension), not num_devices.
2. ✅ num_gates issue is real: The Data object uses num_gates, which is not a standard PyTorch Geometric attribute. It should likely be num_nodes.
3. ❓ Batching solution correctness: The suggested fix using Batch.from_data_list() would work, but it's arguably overcomplicating. A simpler fix would be to squeeze the output: outputs[0] or outputs.squeeze(0).
Let me check if there are any other test files that show how this should work:
<function_calls>
find . -path ./\.git -prune -o -name "test.py" -type f -print | grep -i predictor | head -5
</function_calls>
Let me check for test files:
<function_calls>
find src -name "test.py" -o -name "test_*" -type d
</function_calls>
Let me check if there's a tests directory:
<function_calls>
ls -la | grep -i test
</function_calls>
<function_calls>
find . -maxdepth 3 -name "test*" -type d 2>/dev/null
</function_calls>
Let me verify the current behavior one more time by looking at the exact line that's causing the issue:
<function_calls>
python3 << 'EOF'
with open('src/mqt/predictor/ml/predictor.py', 'r') as f:
lines = f.readlines()
# Focus on lines 862-869
print("=== Lines 862-869 (the issue) ===\n")
for i in range(861, 869):
print(f"{i+1:4d}: {lines[i]}", end='')
print("\n\n=== Let's trace the issue ===\n")
print("1. Line 862: outputs = gnn_model(feature_vector)")
print(" - feature_vector is a single Data object")
print(" - Model returns tensor of shape [1, num_devices]")
print(" - So outputs has shape [1, num_devices]")
print()
print("2. Line 863: assert class_labels is not None")
print(" - class_labels is a list passed as parameter")
print()
print("3. Line 864-865: if len(class_labels) != len(outputs):")
print(" - len(class_labels) = number of devices (e.g., 10)")
print(" - len(outputs) = 1 (batch dimension!)")
print(" - MISMATCH! Will raise ValueError")
print()
print("4. Line 867: pairs = sorted(zip(outputs.tolist(), class_labels, strict=False)...)")
print(" - outputs.tolist() with shape [1, 10] becomes [[...10 scores...]]")
print(" - zip() will only iterate once (length 1)")
print(" - This is wrong! Should iterate num_devices times")
EOF
</function_calls>
Perfect! Now I'm ready to write the rewritten review comment. The review comment is fundamentally correct about the issue, but I should verify:
- Whether Batch is the right solution or if there's a simpler fix
- Whether num_gates/num_nodes is actually causing issues
Let me check one final time if the model training uses num_nodes:
<function_calls>
grep -n "num_nodes|num_gates" src/mqt/predictor/ml/predictor.py | head -20
</function_calls>
Based on my thorough investigation, I can now provide the rewritten review comment:
GNN model receives unbatched single Data object; output shape mismatches class_labels length.
Line 862 passes a single Data object to gnn_model(feature_vector). The model's global_add_pool layer returns a tensor of shape [1, num_devices], but line 864-865 compares len(outputs) (which is 1, the batch dimension) against len(class_labels) (which is the number of device labels, typically > 1). This assertion will fail. The subsequent zip on line 867 expects to iterate over num_devices predictions, not one batch-dimension element.
Additionally, line 862 creates a Data object with num_gates=number_of_gates, but PyTorch Geometric's Data class expects num_nodes as the standard attribute name. While custom attributes are permitted, using the standard name is more reliable.
Fix by removing the batch dimension from the model output:
```python
outputs = gnn_model(feature_vector)
outputs = outputs.squeeze(0)  # Remove batch dimension [1, num_devices] -> [num_devices]
assert class_labels is not None
if len(class_labels) != len(outputs):
    msg = "outputs and class_labels must be same length"
    raise ValueError(msg)
pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
```

Also update line 862 to use the standard PyTorch Geometric attribute:

```python
feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
```

🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 862-869, the GNN is called
with a single Data object so model returns shape [1, num_devices] and the code
incorrectly compares len(outputs) (==1) to len(class_labels) (==num_devices) and
zips them, causing a mismatch; fix by removing the batch dimension from outputs
(e.g., outputs = outputs.squeeze(0) or outputs = outputs[0]) before length
checks and zipping, and also replace the custom Data attribute num_gates with
the standard PyTorch Geometric attribute num_nodes when constructing the Data
object.
Actionable comments posted: 5
♻️ Duplicate comments (2)
src/mqt/predictor/ml/predictor.py (2)
601-701: Critical: Multiple issues will cause runtime failures in GNN prediction. Several previously flagged issues remain unresolved:

Line 607: Dead code - the list comprehension result is discarded:

```python
[dev.description for dev in self.devices]  # Result not assigned
```

Class labels not persisted: `json_dict` stores only hyperparameters, but `predict_device_for_figure_of_merit` (line 860) expects `gnn_model.classes`, which is never set. You must save the device labels:

```diff
 json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+json_dict["classes"] = [dev.description for dev in self.devices]
```

Line 698: Verbose evaluation always calls `evaluate_classification_model`, even when `task="regression"` (Hellinger distance). This produces meaningless metrics:

```diff
 if verbose:
     test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-    avg_loss_test, dict_results, _ = evaluate_classification_model(
-        model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-    )
+    if task == "regression":
+        avg_loss_test, dict_results, _ = evaluate_regression_model(
+            model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+        )
+    else:
+        avg_loss_test, dict_results, _ = evaluate_classification_model(
+            model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
+        )
     print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
```
503-506: Critical: `num_outputs` parameter is unused and `k_folds` may be invalid. Two issues persist from previous reviews:

Line 503: The `num_outputs` parameter is immediately overwritten, making it a dead parameter. Either remove it from the signature or honor the passed value.

Line 506: `KFold(n_splits=k_folds)` will raise `ValueError` if `k_folds < 2`. When `len(training_data.y_train) == 1`, `k_folds` becomes 1.

```diff
-    num_outputs = max(1, len(self.devices))
+    # Use provided num_outputs, or fall back to device count
+    num_outputs = num_outputs if num_outputs > 0 else max(1, len(self.devices))
     # Split into k-folds
+    k_folds = max(2, k_folds)  # KFold requires n_splits >= 2
     kf = KFold(n_splits=k_folds, shuffle=True)
```

Alternatively, skip hyperparameter search entirely for datasets with fewer than 2 samples.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (1)
src/mqt/predictor/ml/predictor.py (12 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/mqt/predictor/ml/predictor.py (4)
src/mqt/predictor/ml/gnn.py (1)
GNN (191-295)
src/mqt/predictor/hellinger/utils.py (1)
get_hellinger_model_path (135-144)
src/mqt/predictor/ml/helper.py (10)
TrainingData (569-579)
create_dag (192-301)
create_feature_vector (164-189)
evaluate_classification_model (323-390)
evaluate_regression_model (393-459)
get_openqasm3_gates (119-151)
get_path_trained_model (48-50)
get_path_trained_model_gnn (53-57)
get_path_training_data (36-38)
train_model (462-565)
src/mqt/predictor/rl/predictor.py (1)
train_model (87-127)
🪛 Ruff (0.14.5)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
571-571: Boolean-typed positional argument in function definition
(FBT001)
571-571: Boolean default positional argument in function definition
(FBT002)
802-802: Boolean-typed positional argument in function definition
(FBT001)
802-802: Boolean default positional argument in function definition
(FBT002)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (3)
src/mqt/predictor/ml/predictor.py (3)
83-98: Type definitions look good. The `GNNSample`, `FeatureSample`, `TrainingSample` type aliases and the `TrainGNNKwargs` TypedDict provide clear typing for the GNN data flow.
322-353: Training data generation logic is correct. The GNN branch properly creates `Data` objects with multi-device scores as targets and preserves `target_label` for later use. The conditional serialization to `.pt` vs `.npy` is appropriate. (A small sketch of this serialization split follows below.)
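For illustration, a self-contained sketch of that serialization split; paths and shapes are made-up placeholders, not the PR's actual file names:

```python
from pathlib import Path

import numpy as np
import torch
from torch_geometric.data import Data

out_dir = Path("training_data")  # hypothetical output directory
out_dir.mkdir(exist_ok=True)

gnn = True
if gnn:
    # Graph samples keep their structure, so they are pickled via torch.save (.pt).
    samples = [Data(x=torch.rand(4, 3), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))]
    torch.save(samples, out_dir / "training_data_graph.pt")
else:
    # Flat feature vectors serialize naturally as a NumPy array (.npy).
    features = np.random.rand(10, 32)
    np.save(out_dir / "training_data.npy", features)
```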
769-771: Security note: `weights_only=False` allows arbitrary code execution. Using `torch.load(..., weights_only=False)` is necessary to deserialize `torch_geometric.data.Data` objects, but it can execute arbitrary Python code if the `.pt` file is malicious. This is acceptable for training data generated by the pipeline itself, but ensure this file is never loaded from untrusted sources.

Consider adding a comment documenting this security consideration.
```python
    gnn: bool = False,
    **gnn_kwargs: Unpack[TrainGNNKwargs],
```
🧹 Nitpick | 🔵 Trivial
Consider keyword-only for gnn parameter.
While functional, making gnn keyword-only (after *) would prevent accidental positional usage and address the Ruff FBT001/FBT002 hints consistently across the codebase.
```diff
 def setup_device_predictor(
     devices: list[Target],
     figure_of_merit: figure_of_merit = "expected_fidelity",
     path_uncompiled_circuits: Path | None = None,
     path_compiled_circuits: Path | None = None,
     path_training_data: Path | None = None,
     timeout: int = 600,
+    *,
     gnn: bool = False,
     **gnn_kwargs: Unpack[TrainGNNKwargs],
 ) -> bool:
```

🧰 Tools
🪛 Ruff (0.14.5)
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 107-108, make the gnn
parameter keyword-only to avoid accidental positional use and satisfy Ruff
FBT001/FBT002: change the function signature so gnn appears after a bare *
(e.g., add a * before gnn) so callers must pass gnn by name, keep its type and
default (gnn: bool = False) and keep **gnn_kwargs unchanged; then update any
local call sites that passed gnn positionally to use the keyword form and run
linters/tests to ensure no breakage.
```diff
 scores_list = scores  # list(scores.values())
 target_label = max(scores, key=lambda k: scores[k])

 qc = QuantumCircuit.from_qasm_file(path_uncompiled_circuit / file)
-feature_vec = create_feature_vector(qc)
-training_sample = (feature_vec, target_label)
+training_sample: TrainingSample
+if self.gnn:
+    x, edge_index, number_of_gates = create_dag(qc)
+    y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
+    training_sample = (x, y, edge_index, number_of_gates, target_label)
+else:
+    feature_vec = create_feature_vector(qc)
+    training_sample = (feature_vec, target_label)
```
🧹 Nitpick | 🔵 Trivial
Clarify or remove the intermediate y tensor that gets discarded.
The y tensor created at line 428 (device index) is immediately discarded in generate_training_data (line 328: x, _y, edge_idx, ...) and replaced with the full device scores tensor. This intermediate value adds confusion without serving a purpose.
Consider either:
- Not computing `y` here and returning a 4-tuple `(x, edge_index, number_of_gates, target_label)`, or
- Adding a comment explaining why this value exists but is unused.

```diff
 if self.gnn:
     x, edge_index, number_of_gates = create_dag(qc)
-    y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
-    training_sample = (x, y, edge_index, number_of_gates, target_label)
+    # Note: y placeholder - actual targets are computed in generate_training_data
+    training_sample = (x, None, edge_index, number_of_gates, target_label)
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 421 to 432, the code
constructs a temporary tensor y (device index) that is never used downstream and
is discarded by generate_training_data; remove the creation of y and change the
training_sample for the GNN branch to a 4-tuple (x, edge_index, number_of_gates,
target_label) to match the rest of the pipeline, or if you must keep it for
clarity/compatibility add a one-line comment explaining it is unused and only
created for legacy reasons; ensure any call sites or tuple unpacking
expectations are updated accordingly to avoid breaking consumers.
```python
x_train, x_val, _y_train, _y_val = train_test_split(
    training_data.X_train, training_data.y_train, test_size=0.2, random_state=5
)
# Dataloader
train_loader = DataLoader(x_train, batch_size=16, shuffle=True)

val_loader = DataLoader(x_val, batch_size=16, shuffle=False)
train_model(
    model,
    train_loader,
    optimizer,
    loss_fn,
    task=task,
    num_epochs=num_epochs,
    device=device,
```
Minor: Double train/val split and incorrect device type.
- Lines 674-676: Data is split again (80/20) after `_get_prepared_training_data` already performed a 70/30 split. This results in effective train/val/test proportions of 56%/14%/30%. If intentional, consider documenting; otherwise, consider using `training_data.X_test` as the validation set.
- Line 688: `train_model` expects `device: str | None`, but a `torch.device` object is passed. This works due to duck typing but is technically a type mismatch. Consider passing `str(device)` or `"cuda" if torch.cuda.is_available() else "cpu"`. (A small sketch of both points follows below.)
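A minimal, self-contained sketch of both points; the data here is a dummy list standing in for the prepared training set:

```python
import torch
from sklearn.model_selection import train_test_split

# Point 2: convert the torch.device to the string form train_model is annotated with.
device = str(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Point 1: a single split is enough if the held-out portion doubles as validation data.
samples = list(range(100))  # dummy stand-in for the prepared training samples
train_samples, val_samples = train_test_split(samples, test_size=0.3, random_state=5)

print(device, len(train_samples), len(val_samples))  # e.g. "cpu 70 30"
```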
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 674 to 688, remove the
redundant train/val split (currently splitting training_data.X_train again) and
instead use the prepared validation set (e.g., training_data.X_test) or
explicitly document why a second split is required; update the DataLoader
instantiation to use the chosen train and val arrays accordingly. Also change
the device argument passed to train_model from a torch.device object to a string
(e.g., str(device) or a conditional "cuda"/"cpu" value) so the parameter matches
the expected type.
…into gnn-branch
…into gnn-branch
…into gnn-branch
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/device_selection/test_predictor_ml.py (1)
75-78: Missing `gnn` parameter in the prediction call will cause a test failure for GNN cases. When the test is parameterized with `gnn=True`, the training creates a GNN model (`.pth` file), but `predict_device_for_figure_of_merit` is called without `gnn=True`. This causes the prediction to look for a non-existent Random Forest model (`.joblib` file), resulting in `FileNotFoundError`.

```diff
 test_qc = get_benchmark("ghz", BenchmarkLevel.ALG, 3)
-predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity")
+predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity", gnn=gnn)
```
♻️ Duplicate comments (10)
src/mqt/predictor/ml/predictor.py (8)
673-679: Double train/validation split reduces effective training data. `_get_prepared_training_data` already performs a 70/30 train/test split. Lines 673-675 split again with 80/20, resulting in effective proportions of 56%/14%/30%. Consider using `training_data.X_test` as the validation set if this is unintentional.
628-632: Critical: `mlp_units` not stored in JSON; prediction will fail with `KeyError`. Line 845 reads `json_dict["mlp_units"]`, but this key is never added to `json_dict`. The code computes `mlp_units` from `mlp_str` but doesn't store it. Additionally, class labels should be stored for GNN prediction.

```diff
 json_dict = study.best_trial.params
 mlp_str = dict_best_hyper["mlp"]
 mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")]
 json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+json_dict["mlp_units"] = mlp_units
+json_dict["classes"] = [dev.description for dev in self.devices]
```
633-662: Reduce code duplication in GNN instantiation. The two GNN construction blocks are nearly identical, differing only in `output_dim`. Consolidate using a computed `output_dim`.

```diff
-if self.figure_of_merit != "hellinger_distance":
-    model = GNN(
-        in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
-        # ... many identical lines ...
-        output_dim=len(self.devices),
-        # ...
-    ).to("cuda" if torch.cuda.is_available() else "cpu")
-else:
-    model = GNN(
-        # ... identical lines ...
-        output_dim=1,
-        # ...
-    ).to("cuda" if torch.cuda.is_available() else "cpu")
+output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+model = GNN(
+    in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+    num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+    hidden_dim=dict_best_hyper["hidden_dim"],
+    num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+    mlp_units=mlp_units,
+    output_dim=output_dim,
+    dropout_p=dict_best_hyper["dropout"],
+    bidirectional=dict_best_hyper["bidirectional"],
+    use_sag_pool=dict_best_hyper["sag_pool"],
+    sag_ratio=0.7,
+    conv_activation=torch.nn.functional.leaky_relu,
+    mlp_activation=torch.nn.functional.leaky_relu,
+).to("cuda" if torch.cuda.is_available() else "cpu")
```
694-699: Verbose test block ignores the `task` variable and always uses the classification evaluator. When `task="regression"` (for Hellinger distance), calling `evaluate_classification_model` produces incorrect metrics. Branch on `task` to call the appropriate evaluator.

```diff
 if verbose:
     test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-    avg_loss_test, dict_results, _ = evaluate_classification_model(
-        model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-    )
+    if task == "regression":
+        avg_loss_test, dict_results, _ = evaluate_regression_model(
+            model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+        )
+    else:
+        avg_loss_test, dict_results, _ = evaluate_classification_model(
+            model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+        )
     print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
```
858-867: Critical: `gnn_model.classes` doesn't exist and an output shape mismatch will cause failures. Multiple issues will cause runtime failures:

Line 858: `gnn_model.classes` is never defined on the `GNN` class and is not set during training. This raises `AttributeError`.

Lines 862-864: For a single `Data` object, `gnn_model()` returns a tensor of shape `[1, num_devices]`. `len(outputs)` returns 1 (the batch dimension), but `len(class_labels)` would be `num_devices`. The assertion always fails.

Line 866: Even if the assertion passed, `outputs.tolist()` on `[1, num_devices]` produces `[[score1, score2, ...]]`, so `zip` would only iterate once.

Fix by loading class labels from JSON (after storing them during training) and squeezing the output:

```diff
-class_labels = gnn_model.classes
+class_labels = json_dict["classes"]  # Must be saved during training
 with torch.no_grad():
     outputs = gnn_model(feature_vector)
+    outputs = outputs.squeeze(0)  # Remove batch dimension: [1, N] -> [N]
 assert class_labels is not None
 if len(class_labels) != len(outputs):
     msg = "outputs and class_labels must be same length"
     raise ValueError(msg)
```
426-429: Intermediate `y` tensor computed but immediately discarded. The `y` tensor at line 428 is created but discarded in `generate_training_data` (line 328 unpacks with `_y`). The actual targets are computed from `scores` in `generate_training_data`. Consider removing this redundant computation or documenting why it exists.

```diff
 if self.gnn:
     x, edge_index, number_of_gates = create_dag(qc)
-    y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float)
-    training_sample = (x, y, edge_index, number_of_gates, target_label)
+    # y placeholder - actual targets are computed in generate_training_data from scores
+    training_sample = (x, None, edge_index, number_of_gates, target_label)
```
505-506: Guard against `KFold` with invalid `n_splits` for tiny datasets. If `k_folds` (derived from `min(len(y_train), 5)`) is less than 2, `KFold(n_splits=k_folds)` raises `ValueError`. For very small datasets, either skip cross-validation or clamp `k_folds`.

```diff
+k_folds = max(2, k_folds)  # KFold requires at least 2 splits
 kf = KFold(n_splits=k_folds, shuffle=True)
```
606-606: Dead expression: the result of the list comprehension is discarded. This line creates a list of device descriptions but never assigns it. If it was intended to store class labels for prediction, assign it to a variable and include it in `json_dict`.

```diff
-[dev.description for dev in self.devices]
+# If needed for later use, assign the result:
+# device_labels = [dev.description for dev in self.devices]
```

Or remove the line entirely if not needed.
src/mqt/predictor/ml/helper.py (2)
304-317: get_results_classes silently corrupts target labels for [N, 1] shaped tensors.

After targets.unsqueeze(1) at line 361, targets become shape [N, 1]. When argmax(targets, dim=1) is applied to [N, 1] tensors, it always returns 0 (since there's only one element along dim 1), regardless of the actual target values. This corrupts accuracy calculations.

Apply this fix to handle both one-hot and single-column targets:

  def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
      """Return predicted and target class indices.

      Arguments:
          preds: model predictions
-         targets: ground truth targets
+         targets: ground truth targets (one-hot encoded or integer class labels)

      Returns:
          pred_idx: predicted class indices
          targets_idx: target class indices
      """
      pred_idx = torch.argmax(preds, dim=1)
-     targets_idx = torch.argmax(targets, dim=1)
+     if targets.dim() == 1:
+         targets_idx = targets.long()
+     elif targets.size(1) == 1:
+         targets_idx = targets.view(-1).long()
+     else:
+         targets_idx = torch.argmax(targets, dim=1)
      return pred_idx, targets_idx
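A quick standalone check (not part of the patch) showing why argmax over a [N, 1] column collapses every label to 0:

import torch

targets = torch.tensor([2.0, 0.0, 1.0]).unsqueeze(1)  # shape [3, 1], true classes 2, 0, 1
print(torch.argmax(targets, dim=1))                    # tensor([0, 0, 0]) - labels lost
print(targets.view(-1).long())                         # tensor([2, 0, 1]) - intended indices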
331-331: Type annotation mismatch: metrics dict stores str but is annotated as dict[str, float].

Line 389 assigns classification_report_res (a string) to metrics["classification_report"], but the return type annotation at line 331 specifies dict[str, float].

-) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
+) -> tuple[float, dict[str, float | str], tuple[np.ndarray, np.ndarray] | None]:

And update the initialization:

- metrics = {"loss": float(avg_loss)}
+ metrics: dict[str, float | str] = {"loss": float(avg_loss)}

Also applies to: 388-389
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (3)
src/mqt/predictor/ml/helper.py (4 hunks)
src/mqt/predictor/ml/predictor.py (12 hunks)
tests/device_selection/test_predictor_ml.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/device_selection/test_predictor_ml.py (3)
src/mqt/predictor/ml/helper.py (1)
get_path_training_data(36-38)
src/mqt/predictor/rl/helper.py (1)
get_path_training_data(97-99)
src/mqt/predictor/ml/predictor.py (1)
Predictor(166-796)
src/mqt/predictor/ml/helper.py (1)
src/mqt/predictor/utils.py (1)
calc_supermarq_features(94-146)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
113-113: Boolean-typed positional argument in function definition
(FBT001)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
570-570: Boolean-typed positional argument in function definition
(FBT001)
570-570: Boolean default positional argument in function definition
(FBT002)
800-800: Boolean-typed positional argument in function definition
(FBT001)
800-800: Boolean default positional argument in function definition
(FBT002)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (13)
tests/device_selection/test_predictor_ml.py (4)
38-43: LGTM! The parameterized test setup correctly covers both Random Forest and GNN paths with appropriate test IDs.
68-73: LGTM! The conditional assertions correctly verify the expected training artifacts for each model type.
95-99: LGTM! The cleanup logic correctly handles both .npy and .pt file extensions.
112-117: LGTM! Appropriate parameterization for testing both GNN and non-GNN paths in _get_prepared_training_data.
src/mqt/predictor/ml/helper.py (6)
53-57: LGTM! The helper correctly returns the path to the GNN checkpoint file using f-string formatting.
119-151: LGTM! The OpenQASM 3.0 gate list includes the spec URL reference and version note for future maintenance.
192-301: LGTM! The DAG construction correctly handles edge cases with empty edges and empty topo_nodes, returning appropriate tensors.
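For readers unfamiliar with the DAG preprocessing, a minimal sketch of the general idea (this is not the PR's create_dag; the feature extraction is simplified and the names are illustrative):

import torch
from qiskit import QuantumCircuit
from qiskit.converters import circuit_to_dag
from qiskit.dagcircuit import DAGOpNode

qc = QuantumCircuit(2)
qc.h(0)
qc.cx(0, 1)

dag = circuit_to_dag(qc)
ops = list(dag.topological_op_nodes())          # one graph node per gate
index = {node: i for i, node in enumerate(ops)}
edges = [
    (index[src], index[dst])
    for src, dst, _wire in dag.edges()
    if isinstance(src, DAGOpNode) and isinstance(dst, DAGOpNode)
]
edge_index = (
    torch.tensor(edges, dtype=torch.long).t().contiguous()
    if edges
    else torch.empty((2, 0), dtype=torch.long)  # same empty-edge guard the review mentions
)
print(edge_index.shape)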
405-469: LGTM! The regression evaluator correctly handles empty loaders and provides appropriate metrics.
472-571: LGTM! The training loop correctly handles early stopping, validation, and best model restoration.
574-585: LGTM! The TrainingData dataclass appropriately supports both classical (numpy) and GNN (graph) model data types.
src/mqt/predictor/ml/predictor.py (3)
92-98: LGTM! The TrainGNNKwargs TypedDict provides clear type hints for GNN training arguments.
100-163: LGTM! The setup_device_predictor function correctly routes to GNN or Random Forest training based on the gnn flag.
327-341: LGTM! The GNN training sample construction correctly creates Data objects with appropriate attributes.
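As an illustration of the pattern described above (made-up scores and shapes, not PR code), a graph training sample pairs node features and edges with a per-device score vector:

import torch
from torch_geometric.data import Data

scores = {"device_a": 0.93, "device_b": 0.71}        # hypothetical per-device scores
labels = sorted(scores)                               # fixed ordering for the target vector
sample = Data(
    x=torch.rand((4, 8)),                             # dummy node features
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
    y=torch.tensor([scores[k] for k in labels], dtype=torch.float32),
    num_nodes=4,
    target_label=max(scores, key=scores.get),
)
print(sample)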
src/mqt/predictor/ml/predictor.py
Outdated
).to("cuda" if torch.cuda.is_available() else "cpu")
gnn_model.load_state_dict(torch.load(path))
x, edge_index, number_of_gates = create_dag(qc)
feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
Use standard PyTorch Geometric attribute num_nodes instead of custom num_gates.
PyTorch Geometric's Data class expects num_nodes as the standard attribute. Using num_gates may cause issues with some GNN operations that rely on the standard attribute.
- feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
+ feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around line 856, the Data object is created
with a nonstandard attribute num_gates; replace it with the PyTorch Geometric
standard num_nodes by setting num_nodes=number_of_gates (and if backward
compatibility is needed you can also keep num_gates=number_of_gates), so change
the Data(...) call to use num_nodes instead of num_gates.
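A small standalone example of the standard attribute (plain PyTorch Geometric usage, not PR code); num_nodes is the key PyG itself reads, while a custom key such as num_gates is stored as extra data and ignored by loaders and pooling operators:

import torch
from torch_geometric.data import Data

x = torch.rand((3, 5))
edge_index = torch.tensor([[0, 1], [1, 2]])
data = Data(x=x, edge_index=edge_index, num_nodes=3)
print(data.num_nodes)  # 3, used by PyG internals such as batching and pooling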
Actionable comments posted: 1
♻️ Duplicate comments (8)
src/mqt/predictor/ml/helper.py (2)
304-317: get_results_classes silently corrupts targets when passed integer class labels.

This function always applies argmax(dim=1) on targets, which only works correctly for one-hot encoded targets. When targets are 1D integer labels or shape [N, 1], argmax returns all zeros, corrupting accuracy metrics. Per evaluate_classification_model at lines 360-361, targets are unsqueezed to [N, 1] when 1D, which will then fail here.

Apply this diff to handle both encodings:

  def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
      """Return predicted and target class indices.

      Arguments:
          preds: model predictions
-         targets: ground truth targets
+         targets: ground truth targets (one-hot encoded or integer class labels)

      Returns:
          pred_idx: predicted class indices
          targets_idx: target class indices
      """
      pred_idx = torch.argmax(preds, dim=1)
-     targets_idx = torch.argmax(targets, dim=1)
+     if targets.dim() == 1:
+         targets_idx = targets.long()
+     elif targets.size(1) == 1:
+         targets_idx = targets.view(-1).long()
+     else:
+         targets_idx = torch.argmax(targets, dim=1)
      return pred_idx, targets_idx
323-331: Return type annotation is too narrow for the metrics dict.

The function is annotated as returning dict[str, float], but metrics["classification_report"] (line 389) is assigned a string from classification_report(). This will cause type checker warnings.

Apply this diff:

-) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
+) -> tuple[float, dict[str, float | str], tuple[np.ndarray, np.ndarray] | None]:

And update line 375:

- metrics = {"loss": float(avg_loss)}
+ metrics: dict[str, float | str] = {"loss": float(avg_loss)}

src/mqt/predictor/ml/predictor.py (6)
421-434: Intermediate y tensor is computed but discarded.

The y tensor (device index) at line 428 is immediately discarded in generate_training_data (line 328: x, _y, edge_idx, ...) and replaced with the full device scores tensor. This intermediate computation serves no purpose and adds confusion.
500-506: num_outputs parameter is ignored; k_folds can cause a KFold error.

Two issues:

Line 503 overwrites the num_outputs parameter with max(1, len(self.devices)), making the function argument useless.

Line 506 creates KFold(n_splits=k_folds), but k_folds comes from min(len(y_train), 5) at line 609. If the dataset has only 1 sample, k_folds=1, which is invalid for KFold.

Apply this diff:

- num_outputs = max(1, len(self.devices))
+ # Honor the parameter if provided, otherwise use device count
+ num_outputs = max(1, num_outputs) if num_outputs > 0 else max(1, len(self.devices))
  # Split into k-folds
+ k_folds = max(2, k_folds)  # KFold requires at least 2 splits
  kf = KFold(n_splits=k_folds, shuffle=True)
626-666: Critical: mlp_units and classes not stored in JSON; prediction will fail.

Three issues will cause KeyError or AttributeError at prediction time:

Lines 628-630: json_dict is populated from study.best_trial.params, which stores "mlp" as a string (e.g., "64,32"), but prediction at line 845 reads json_dict["mlp_units"] expecting a list.

Line 632: num_outputs is stored but classes (device labels) are not. Line 858 tries to access gnn_model.classes, which doesn't exist.

Lines 633-662: Duplicate GNN instantiation can be consolidated.

Apply this diff to store required metadata:

  json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+ json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
+ json_dict["classes"] = [dev.description for dev in self.devices]  # Store device labels
+
+ # Consolidate model instantiation
+ output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+ model = GNN(
+     in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+     num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+     hidden_dim=dict_best_hyper["hidden_dim"],
+     num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+     mlp_units=mlp_units,
+     output_dim=output_dim,
+     dropout_p=dict_best_hyper["dropout"],
+     bidirectional=dict_best_hyper["bidirectional"],
+     use_sag_pool=dict_best_hyper["sag_pool"],
+     sag_ratio=0.7,
+     conv_activation=torch.nn.functional.leaky_relu,
+     mlp_activation=torch.nn.functional.leaky_relu,
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
- if self.figure_of_merit != "hellinger_distance":
-     model = GNN(
-         ...
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
- else:
-     model = GNN(
-         ...
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
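A minimal sketch of persisting and reloading that metadata next to the checkpoint (file name and keys are assumptions for illustration, not the PR's actual layout):

import json
from pathlib import Path

meta = {
    "mlp_units": [64, 32],                   # parsed once at training time
    "classes": ["device_a", "device_b"],     # hypothetical device labels
    "num_outputs": 2,
}
path = Path("gnn_expected_fidelity.json")    # assumed file name
path.write_text(json.dumps(meta))

json_dict = json.loads(path.read_text())     # at prediction time
assert json_dict["classes"][0] == "device_a"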
687-699: Device type mismatch and wrong evaluator for regression tasks.

Two issues:

Line 687: train_model expects device: str | None, but device is a torch.device object. This works due to duck typing but is a type mismatch.

Lines 694-699: The verbose block always calls evaluate_classification_model, even when task="regression" (Hellinger distance). This produces meaningless metrics for regression.

Apply this diff:

  train_model(
      model,
      train_loader,
      optimizer,
      loss_fn,
      task=task,
      num_epochs=num_epochs,
-     device=device,
+     device=str(device),
      verbose=verbose,
      val_loader=val_loader,
      patience=30,
      min_delta=0.0,
      restore_best=True,
  )

  if verbose:
      test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-     avg_loss_test, dict_results, _ = evaluate_classification_model(
-         model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-     )
+     if task == "regression":
+         avg_loss_test, dict_results, _ = evaluate_regression_model(
+             model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+         )
+     else:
+         avg_loss_test, dict_results, _ = evaluate_classification_model(
+             model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+         )
      print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
835-867: Critical: Multiple bugs will cause GNN prediction to fail at runtime.

Four issues in the GNN prediction path:

Line 845: json_dict["mlp_units"] will raise KeyError since training saves "mlp" as a string, not "mlp_units" as a list.

Line 856: Uses num_gates but PyTorch Geometric expects num_nodes. Inconsistent with training code at line 335.

Line 858: gnn_model.classes is never defined on the GNN class. This will raise AttributeError.

Lines 860-864: Output shape mismatch. The model returns shape [1, num_devices] for a single Data object, but len(outputs) returns 1 (batch dimension), not num_devices. The assertion will fail.

Apply this diff:

  gnn_model.load_state_dict(torch.load(path))
  x, edge_index, number_of_gates = create_dag(qc)
- feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates)
+ feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
  gnn_model.eval()
- class_labels = gnn_model.classes
+ class_labels = json_dict["classes"]  # Load from JSON metadata
  with torch.no_grad():
      outputs = gnn_model(feature_vector)
+ outputs = outputs.squeeze(0)  # Remove batch dimension [1, num_devices] -> [num_devices]
  assert class_labels is not None
  if len(class_labels) != len(outputs):
      msg = "outputs and class_labels must be same length"
      raise ValueError(msg)
  pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
mlpfrom JSON:gnn_model = GNN( in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1), num_conv_wo_resnet=json_dict["num_conv_wo_resnet"], hidden_dim=json_dict["hidden_dim"], num_resnet_layers=json_dict["num_resnet_layers"], - mlp_units=json_dict["mlp_units"], + mlp_units=json_dict.get("mlp_units") or ([] if json_dict["mlp"] == "none" else [int(x) for x in json_dict["mlp"].split(",")]), output_dim=json_dict["num_outputs"],
606-606: Dead code: list comprehension result is discarded.Line 606 computes
[dev.description for dev in self.devices]but the result is not assigned to any variable. This appears to be leftover code.Remove the dead code:
- [dev.description for dev in self.devices]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (3)
src/mqt/predictor/ml/helper.py(4 hunks)src/mqt/predictor/ml/predictor.py(12 hunks)tests/device_selection/test_predictor_ml.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/mqt/predictor/ml/helper.py (1)
src/mqt/predictor/utils.py (1)
calc_supermarq_features(94-146)
tests/device_selection/test_predictor_ml.py (3)
src/mqt/predictor/ml/helper.py (1)
get_path_training_data(36-38)src/mqt/predictor/rl/helper.py (1)
get_path_training_data(97-99)src/mqt/predictor/ml/predictor.py (1)
Predictor(166-796)
🪛 Ruff (0.14.6)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
570-570: Boolean-typed positional argument in function definition
(FBT001)
570-570: Boolean default positional argument in function definition
(FBT002)
800-800: Boolean-typed positional argument in function definition
(FBT001)
800-800: Boolean default positional argument in function definition
(FBT002)
tests/device_selection/test_predictor_ml.py
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
113-113: Boolean-typed positional argument in function definition
(FBT001)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (14)
tests/device_selection/test_predictor_ml.py (4)
38-43: Parameterized test setup looks correct.The test parameterization covers the three key scenarios: RF-only, GNN-only, and GNN with verbose output. This provides good coverage for the dual training paths.
57-64: Test invocation correctly passes gnn and verbose parameters.The
setup_device_predictorcall properly forwards thegnnandverboseparameters to test both training paths.
68-73: Conditional assertions correctly validate expected artifacts.The test properly checks for
graph_dataset_expected_fidelity.ptwhengnn=Trueand the numpy-based files whengnn=False. This aligns with the dual data storage approach.
112-117: Parameterized test correctly covers both gnn and rf paths for error handling.The test now validates that
FileNotFoundErroris raised for both the RF and GNN paths when training data is missing.src/mqt/predictor/ml/helper.py (6)
53-58: GNN checkpoint path helper looks correct.The helper properly constructs the path to the GNN checkpoint file using an f-string as suggested in prior review.
119-152: OpenQASM 3.0 gate list is well-documented.The gate list includes a reference to the OpenQASM 3.0 specification and a note about verification. This is helpful for future maintenance.
192-301: DAG featurization handles edge cases correctly.The
create_dagfunction now includes proper guards for:
- Empty edges (returns
torch.empty((2, 0), dtype=torch.long))- Empty
topo_nodes(returns early with zero critical flags)The feature layout is well-documented in the docstring.
405-469: Regression evaluator is well-structured.The function correctly handles edge cases (empty predictions, constant targets for R²), uses proper metric computation, and has consistent verbose output.
472-572: Training loop with early stopping is correctly implemented.The
train_modelfunction properly:
- Handles both classification and regression tasks
- Implements early stopping with patience and min_delta
- Restores best model weights when
restore_best=True- Transfers model to the correct device
574-585: TrainingData dataclass correctly supports both numpy and graph data.The widened type hints (
NDArray | list[Data]for X,NDArray | Tensorfor y) properly accommodate both classical and GNN training paths.src/mqt/predictor/ml/predictor.py (4)
18-63: Imports are well-organized for the new GNN functionality.The new imports for
torch,optuna,torch_geometric, and related modules are properly structured with TYPE_CHECKING guards where appropriate.
83-98: Type definitions clearly document the sample formats.The
GNNSample,FeatureSample, andTrainingSampletype aliases make the dual data paths explicit. TheTrainGNNKwargsTypedDict is a clean way to forward GNN-specific arguments.
100-163: Setup function correctly routes to GNN or RF training paths.The
setup_device_predictorfunction properly:
- Accepts
gnnflag andgnn_kwargs- Routes to
train_gnn_modelwhengnn=True- Maintains backward compatibility with RF path
327-341: GNN training sample creation correctly populates Data object.The code properly creates a
torch_geometric.data.Dataobject with node features, edge index, and target labels from the score dictionary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/device_selection/test_predictor_ml.py (1)
75-78: Prediction test always uses RF path regardless ofgnnparameter.Line 76 calls
predict_device_for_figure_of_meritwithout passinggnn=gnn. Whengnn=True, the test trains a GNN model but then attempts prediction with the RF model (which may not exist), causing a mismatch.test_qc = get_benchmark("ghz", BenchmarkLevel.ALG, 3) - predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity") + predicted = predict_device_for_figure_of_merit(test_qc, figure_of_merit="expected_fidelity", gnn=gnn) assert predicted.description == "ibm_falcon_127"
♻️ Duplicate comments (10)
src/mqt/predictor/ml/predictor.py (9)
107-108: Consider makinggnnkeyword-only to satisfy Ruff FBT001/FBT002.Boolean positional arguments can be confusing at call sites. Adding
*,beforegnnforces callers to use the keyword form.def setup_device_predictor( devices: list[Target], figure_of_merit: figure_of_merit = "expected_fidelity", path_uncompiled_circuits: Path | None = None, path_compiled_circuits: Path | None = None, path_training_data: Path | None = None, timeout: int = 600, + *, gnn: bool = False, **gnn_kwargs: Unpack[TrainGNNKwargs], ) -> bool:
421-429: The intermediateytensor (line 428) is computed but discarded.The
ytensor created here is unpacked as_yand ignored ingenerate_training_data(line 328), then replaced with a different tensor. This computation is wasted.Also, line 421 assigns a dict to
scores_list, which is misleading naming.- scores_list = scores # list(scores.values()) + scores_dict = scores target_label = max(scores, key=lambda k: scores[k]) qc = QuantumCircuit.from_qasm_file(path_uncompiled_circuit / file) training_sample: TrainingSample if self.gnn: x, edge_index, number_of_gates = create_dag(qc) - y = torch.tensor([[dev.description for dev in self.devices].index(target_label)], dtype=torch.float) - training_sample = (x, y, edge_index, number_of_gates, target_label) + training_sample = (x, None, edge_index, number_of_gates, target_label)
503-506:num_outputsparameter is overwritten;k_foldsmay be invalid for tiny datasets.
Line 503 overwrites the
num_outputsparameter passed to the function, making it effectively unused. Either remove it from the signature or honor the passed value.Line 506 creates
KFold(n_splits=k_folds), butk_foldscomes frommin(len(training_data.y_train), 5)intrain_gnn_model(line 609), which could be 1 for a single sample.KFoldrequiresn_splits >= 2.- num_outputs = max(1, len(self.devices)) + # Use passed num_outputs or fall back to device count + num_outputs = num_outputs if num_outputs > 0 else max(1, len(self.devices)) # Split into k-folds + k_folds = max(2, k_folds) # KFold requires at least 2 splits kf = KFold(n_splits=k_folds, shuffle=True)
626-632: Critical:mlp_unitsandclassesnot stored in JSON metadata; prediction will fail.The
json_dictis populated fromstudy.best_trial.params, which contains"mlp"as a string (e.g.,"64,32"). However,predict_device_for_figure_of_meritat line 845 expectsjson_dict["mlp_units"]as a list, which will raiseKeyError.Additionally, class labels must be stored for inference since
gnn_model.classes(line 858) is never defined on the GNN class.json_dict = study.best_trial.params mlp_str = dict_best_hyper["mlp"] mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")] json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1 + json_dict["mlp_units"] = mlp_units # Store parsed list for inference + json_dict["classes"] = [dev.description for dev in self.devices] # Store class labels
633-662: Reduce duplication in GNN model construction.The two branches are nearly identical, differing only in
output_dim. Consolidate:- if self.figure_of_merit != "hellinger_distance": - model = GNN( - in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1), - ... - output_dim=len(self.devices), - ... - ).to("cuda" if torch.cuda.is_available() else "cpu") - else: - model = GNN( - ... - output_dim=1, - ... - ).to("cuda" if torch.cuda.is_available() else "cpu") + output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices) + device_str = "cuda" if torch.cuda.is_available() else "cpu" + model = GNN( + in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1), + num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"], + hidden_dim=dict_best_hyper["hidden_dim"], + num_resnet_layers=dict_best_hyper["num_resnet_layers"], + mlp_units=mlp_units, + output_dim=output_dim, + dropout_p=dict_best_hyper["dropout"], + bidirectional=dict_best_hyper["bidirectional"], + use_sag_pool=dict_best_hyper["sag_pool"], + sag_ratio=0.7, + conv_activation=torch.nn.functional.leaky_relu, + mlp_activation=torch.nn.functional.leaky_relu, + ).to(device_str)
687-687: Type mismatch:train_modelexpectsdevice: str | None, but receivestorch.device.Line 687 passes
device(atorch.deviceobject) totrain_model, but the function signature expectsstr | None. This works due to duck typing but is a type annotation mismatch.- device=device, + device=str(device),
694-699: Verbose evaluation ignores task type; always uses classification evaluator.For
"hellinger_distance"or"estimated_hellinger_distance",taskis set to"regression"(line 601), but the verbose block unconditionally callsevaluate_classification_model. This produces incorrect metrics for regression tasks.if verbose: test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False) - avg_loss_test, dict_results, _ = evaluate_classification_model( - model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose - ) + if task == "regression": + avg_loss_test, dict_results, _ = evaluate_regression_model( + model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose + ) + else: + avg_loss_test, dict_results, _ = evaluate_classification_model( + model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose + ) print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
840-858: Critical: Multiple bugs will cause GNN prediction to fail at runtime.
Line 845:
json_dict["mlp_units"]will raiseKeyError. The training saves"mlp"(a string), not"mlp_units"(a list).Line 854:
torch.load(path)withoutweights_only=Trueraises a security warning in PyTorch 2.x.Line 856: Uses
num_gates=number_of_gates, but Data objects created during training usenum_nodes. While functionally equivalent, the inconsistent naming is confusing.Line 858:
gnn_model.classesis never defined on theGNNclass. This will raiseAttributeError.gnn_model = GNN( in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1), num_conv_wo_resnet=json_dict["num_conv_wo_resnet"], hidden_dim=json_dict["hidden_dim"], num_resnet_layers=json_dict["num_resnet_layers"], - mlp_units=json_dict["mlp_units"], + mlp_units=json_dict.get("mlp_units") or ([] if json_dict["mlp"] == "none" else [int(x) for x in json_dict["mlp"].split(",")]), output_dim=json_dict["num_outputs"], dropout_p=json_dict["dropout"], bidirectional=json_dict["bidirectional"], use_sag_pool=json_dict["sag_pool"], sag_ratio=0.7, conv_activation=torch.nn.functional.leaky_relu, mlp_activation=torch.nn.functional.leaky_relu, ).to("cuda" if torch.cuda.is_available() else "cpu") - gnn_model.load_state_dict(torch.load(path)) + gnn_model.load_state_dict(torch.load(path, weights_only=True)) x, edge_index, number_of_gates = create_dag(qc) - feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates) + feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates) gnn_model.eval() - class_labels = gnn_model.classes + class_labels = json_dict.get("classes") # Load from JSON metadata
859-867: Output shape mismatch:len(outputs)returns batch dimension, not num_classes.The GNN model returns a tensor of shape
[1, num_devices]for a single Data object.len(outputs)returns 1 (the batch dimension), not the number of device scores. The comparison at line 862 will always fail unless there's exactly one device.Additionally,
outputs.tolist()on shape[1, N]returns[[score1, score2, ...]], causing the zip to iterate only once.with torch.no_grad(): outputs = gnn_model(feature_vector) + outputs = outputs.squeeze(0) # Remove batch dimension: [1, num_devices] -> [num_devices] assert class_labels is not None if len(class_labels) != len(outputs): msg = "outputs and class_labels must be same length" raise ValueError(msg) pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)src/mqt/predictor/ml/helper.py (1)
304-317:get_results_classescorrupts target indices for integer labels.When targets are 1D or
[N, 1]shaped (integer class indices),argmax(targets, dim=1)returns all zeros regardless of actual values. Inevaluate_classification_model, targets are unsqueezed to[N, 1]at lines 360-361, triggering this bug.Make the function robust to both encodings:
def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Return predicted and target class indices. Arguments: preds: model predictions - targets: ground truth targets + targets: ground truth targets (one-hot encoded or integer class labels) Returns: pred_idx: predicted class indices targets_idx: target class indices """ pred_idx = torch.argmax(preds, dim=1) - targets_idx = torch.argmax(targets, dim=1) + if targets.dim() == 1: + targets_idx = targets.long() + elif targets.size(1) == 1: + targets_idx = targets.view(-1).long() + else: + targets_idx = torch.argmax(targets, dim=1) return pred_idx, targets_idx
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (3)
src/mqt/predictor/ml/helper.py(4 hunks)src/mqt/predictor/ml/predictor.py(12 hunks)tests/device_selection/test_predictor_ml.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/device_selection/test_predictor_ml.py (2)
src/mqt/predictor/ml/helper.py (1)
get_path_training_data(36-38)src/mqt/predictor/rl/helper.py (1)
get_path_training_data(97-99)
src/mqt/predictor/ml/predictor.py (4)
src/mqt/predictor/ml/gnn.py (1)
GNN(191-295)src/mqt/predictor/hellinger/utils.py (1)
get_hellinger_model_path(135-144)src/mqt/predictor/ml/helper.py (12)
TrainingData(575-585)create_dag(192-301)create_feature_vector(164-189)evaluate_classification_model(323-402)evaluate_regression_model(405-469)get_openqasm3_gates(119-151)get_path_trained_model(48-50)get_path_trained_model_gnn(53-57)get_path_training_circuits(60-62)get_path_training_circuits_compiled(65-67)get_path_training_data(36-38)train_model(472-571)src/mqt/predictor/rl/helper.py (3)
get_path_trained_model(102-104)get_path_training_circuits(107-109)get_path_training_data(97-99)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
113-113: Boolean-typed positional argument in function definition
(FBT001)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
570-570: Boolean-typed positional argument in function definition
(FBT001)
570-570: Boolean default positional argument in function definition
(FBT002)
800-800: Boolean-typed positional argument in function definition
(FBT001)
800-800: Boolean default positional argument in function definition
(FBT002)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (11)
src/mqt/predictor/ml/predictor.py (3)
18-86: LGTM on imports and type definitions.The new imports for GNN support (torch, torch_geometric, optuna, gc, json) and the type aliases (
GNNSample,FeatureSample,TrainingSample) are well-organized and correctly typed.
92-98: LGTM on TrainGNNKwargs TypedDict.Good use of
TypedDictwithtotal=Falseto make all GNN-specific arguments optional. This cleanly separates GNN hyperparameters from the main function signature.
762-781: LGTM on conditional data loading for GNN vs RF paths.The branching logic correctly handles both numpy arrays for RF and PyTorch tensors for GNN, with appropriate file formats (.npy vs .pt) and data extraction methods.
tests/device_selection/test_predictor_ml.py (3)
38-43: Good test parameterization for RF and GNN paths.The parameterization with
ids=["rf", "gnn", "gnn_verbose"]provides clear test identification and covers both model types plus verbose output.
96-99: LGTM on cleanup handling both file formats.The cleanup correctly removes both
.npyand.ptfiles to handle artifacts from both RF and GNN test runs.
112-117: LGTM on parameterized test for_get_prepared_training_data.The test correctly parameterizes the
gnnflag and verifies thatFileNotFoundErroris raised when training data is missing for both paths.src/mqt/predictor/ml/helper.py (5)
53-57: LGTM onget_path_trained_model_gnnhelper.Clean implementation using f-string and clear docstring describing it returns a checkpoint file path.
119-151: LGTM on OpenQASM 3.0 gate list.Good documentation with spec reference and version snapshot note for future maintenance.
192-301: LGTM oncreate_dagimplementation.Robust handling of edge cases:
- Empty edges correctly produce a
(2, 0)tensor (lines 265-268)- Empty
topo_nodesreturns early with zero critical flags (lines 272-276)- Unknown gates raise a clear
ValueError(lines 247-249)The feature vector layout is well-documented in the docstring.
405-469: LGTM on regression evaluator.Clean implementation with proper handling of edge cases (empty results, constant targets for R² calculation). The verbose output provides useful diagnostic information.
574-585: LGTM on updatedTrainingDatadataclass.The widened type annotations correctly support both numpy arrays for RF and
torch_geometric.data.Datalists for GNN. The docstring clarifies the dual-mode support.
def evaluate_classification_model(
    model: nn.Module,
    loader: torch_geometric.loader.DataLoader,
    loss_fn: nn.Module,
    device: str,
    *,
    return_arrays: bool = False,
    verbose: bool = False,
) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]:
    """Evaluate a classification model with the given loss function and compute accuracy metrics.

    Arguments:
        model: classification model to be evaluated
        loader: data loader for the evaluation dataset
        loss_fn: loss function for evaluation
        device: device to be used for evaluation (cuda or cpu)
        return_arrays: whether to return prediction and target arrays
        verbose: whether to print the metrics results.

    Returns:
        avg_loss: average loss over the loader
        metrics: {"custom_accuracy": ..., "classification_report": ..., "mse": ..., "rmse": ..., "mae": ..., "r2": ...}
        arrays: (preds, y_true) if return_arrays=True, else None.
    """
    device = torch.device(device)

    model.eval()
    total_loss, total = 0.0, 0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for batch in loader:
            batch_device = batch.to(device)
            preds = model(batch_device)
            preds = torch.clamp(preds, 0.0, 1.0)
            targets = batch_device.y.float()

            if targets.dim() == 1:
                targets = targets.unsqueeze(1)
            if preds.shape != targets.shape:
                msg = f"Shape mismatch: preds {preds.shape} vs targets {targets.shape}"
                raise ValueError(msg)

            bs = targets.size(0)
            loss = loss_fn(preds, targets)
            total_loss += loss.item() * bs
            total += bs

            all_preds.append(preds.detach().cpu())
            all_targets.append(targets.detach().cpu())

    avg_loss = total_loss / max(1, total)
    metrics = {"loss": float(avg_loss)}

    if not all_preds or not all_targets:
        arrays = (np.array([]), np.array([])) if return_arrays else None
        return avg_loss, metrics, arrays

    preds = torch.cat(all_preds, dim=0)
    targets = torch.cat(all_targets, dim=0)

    # --- compute accuracy ---
    pred_classes, target_classes = get_results_classes(preds, targets)
    acc = accuracy_score(target_classes, pred_classes)
    classification_report_res = classification_report(target_classes, pred_classes)
    metrics["custom_accuracy"] = float(acc)
    metrics["classification_report"] = classification_report_res

    if verbose:
        mse = mean_squared_error(targets.numpy().reshape(-1), preds.numpy().reshape(-1))
        mae = mean_absolute_error(targets.numpy().reshape(-1), preds.numpy().reshape(-1))
        rmse = float(np.sqrt(mse))
        if targets.size(0) < 2 or torch.all(targets == targets[0]):
            r2 = float("nan")
        else:
            r2 = float(r2_score(targets.numpy().reshape(-1), preds.numpy().reshape(-1)))
        metrics.update({"mse": float(mse), "rmse": float(rmse), "mae": float(mae), "r2": float(r2)})

    arrays = (preds.numpy(), targets.numpy()) if return_arrays else None
    return avg_loss, metrics, arrays
🧹 Nitpick | 🔵 Trivial
LGTM on classification evaluator with empty loader guard.
The empty loader check (lines 377-379) prevents torch.cat crash on empty lists. The evaluation logic correctly computes accuracy and classification report.
One minor type annotation issue: the return type annotation is dict[str, float], but metrics["classification_report"] is a string. Consider widening to dict[str, float | str].
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/helper.py around lines 323 to 402, the return type
annotation currently declares the metrics dict as dict[str, float] but the
function stores a classification report string under "classification_report";
change the return type to reflect both numeric and string values (for example
dict[str, float | str] or dict[str, Any]) and update the function signature
annotation accordingly so the declared return type matches the actual metrics
contents.
if val_loader is not None:
    if task == "classification":
        val_loss, val_metrics, _ = evaluate_classification_model(
            model, val_loader, loss_fn, device=str(device), verbose=True
        )
    elif task == "regression":
        val_loss, val_metrics, _ = evaluate_regression_model(
            model, val_loader, loss_fn, device=str(device), verbose=True
        )
🧹 Nitpick | 🔵 Trivial
Validation evaluation always uses verbose=True, causing noisy output.
Lines 537 and 541 hardcode verbose=True for validation evaluation, which prints metrics every epoch regardless of the outer verbose parameter. This may produce excessive output during training.
if task == "classification":
val_loss, val_metrics, _ = evaluate_classification_model(
- model, val_loader, loss_fn, device=str(device), verbose=True
+ model, val_loader, loss_fn, device=str(device), verbose=False
)
elif task == "regression":
val_loss, val_metrics, _ = evaluate_regression_model(
- model, val_loader, loss_fn, device=str(device), verbose=True
+ model, val_loader, loss_fn, device=str(device), verbose=False
)🤖 Prompt for AI Agents
In src/mqt/predictor/ml/helper.py around lines 534 to 542, the validation
evaluation calls hardcode verbose=True which forces per-epoch printing; change
both evaluate_classification_model and evaluate_regression_model calls to pass
the surrounding function's verbose parameter (e.g., verbose=verbose) so
validation logging respects the outer verbosity flag (or explicitly pass False
if you want to silence validation always).
if self.gnn:
    x, _y, edge_idx, n_nodes, target_label = training_sample
    name_device = sorted(scores.keys())
    value_device = [scores[i] for i in name_device]
    gnn_training_sample = Data(
        x=x,
        y=torch.tensor(value_device, dtype=torch.float32),
        edge_index=edge_idx,
        num_nodes=n_nodes,
        target_label=target_label,
    )

training_data.append(gnn_training_sample if self.gnn else training_sample)
names_list.append(circuit_name)
scores_list.append(scores)
Device score ordering may diverge from self.devices order.
Line 329 uses sorted(scores.keys()) for device ordering, but self.devices is sorted by dev.description at initialization (line 189-191). If the sorted order of scores.keys() differs from the sorted order of device descriptions, the y tensor indices won't match the model's expected output ordering.
Consider using the canonical self.devices ordering:
if self.gnn:
x, _y, edge_idx, n_nodes, target_label = training_sample
- name_device = sorted(scores.keys())
- value_device = [scores[i] for i in name_device]
+ # Use canonical device ordering from self.devices
+ value_device = [scores.get(dev.description, -1.0) for dev in self.devices]
gnn_training_sample = Data(
x=x,
        y=torch.tensor(value_device, dtype=torch.float32),
Actionable comments posted: 3
♻️ Duplicate comments (8)
src/mqt/predictor/ml/predictor.py (6)
633-662: Consolidate duplicate GNN instantiation.The two branches (lines 634-647 and 649-662) are nearly identical, differing only in
output_dim. This duplication can be eliminated.- if self.figure_of_merit != "hellinger_distance": - model = GNN( - in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1), - ... - output_dim=len(self.devices), - ... - ).to("cuda" if torch.cuda.is_available() else "cpu") - else: - model = GNN( - ... - output_dim=1, - ... - ).to("cuda" if torch.cuda.is_available() else "cpu") + output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices) + device_str = "cuda" if torch.cuda.is_available() else "cpu" + model = GNN( + in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1), + num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"], + hidden_dim=dict_best_hyper["hidden_dim"], + num_resnet_layers=dict_best_hyper["num_resnet_layers"], + mlp_units=mlp_units, + output_dim=output_dim, + dropout_p=dict_best_hyper["dropout"], + bidirectional=dict_best_hyper["bidirectional"], + use_sag_pool=dict_best_hyper["sag_pool"], + sag_ratio=0.7, + conv_activation=torch.nn.functional.leaky_relu, + mlp_activation=torch.nn.functional.leaky_relu, + ).to(device_str)
673-687: Double train/val split and type mismatch intrain_modelcall.
Double split:
_get_prepared_training_dataalready performs a 70/30 train/test split (line 784). Here, lines 673-675 split again (80/20), resulting in effective proportions of ~56%/14%/30%. If intentional, document this; otherwise usetraining_data.X_testas the validation set.Type mismatch: Line 687 passes
device(atorch.deviceobject) totrain_model, which expectsdevice: str | None. While this works due to duck typing, it's inconsistent.- device=device, + device=str(device),
854-867: Multiple critical bugs will cause GNN prediction to fail at runtime.
Line 856: Uses
num_gatesbut PyTorch Geometric'sDataclass expectsnum_nodes. Use the standard attribute name for consistency:-feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates) +feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)Line 858:
gnn_model.classesis never defined on theGNNclass and is never set during training. This will raiseAttributeError. Load class labels from JSON metadata instead (after addingjson_dict["classes"]during training as noted above):-class_labels = gnn_model.classes +class_labels = json_dict["classes"]Lines 862-864: The length comparison is incorrect.
outputshas shape[1, num_devices](batch dimension from single Data object), solen(outputs)returns 1, not the number of devices. The assertion will always fail unless there's exactly one device. Fix by squeezing the batch dimension:+outputs = outputs.squeeze(0) # Remove batch dim: [1, num_devices] -> [num_devices] assert class_labels is not None if len(class_labels) != len(outputs):
628-632:mlp_unitsnot stored in JSON; inference will fail withKeyError.Line 845 in
predict_device_for_figure_of_meritreadsjson_dict["mlp_units"], but here only the rawmlpstring is saved (viastudy.best_trial.params). The parsedmlp_unitslist must be stored explicitly.mlp_str = dict_best_hyper["mlp"] mlp_units = [] if mlp_str == "none" else [int(x) for x in mlp_str.split(",")] json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1 + json_dict["mlp_units"] = mlp_units
503-506:num_outputsparameter is overwritten;k_foldsmay be invalid for tiny datasets.
Line 503 overwrites
num_outputswithmax(1, len(self.devices)), making the function parameter effectively unused. Either remove the parameter or honor it.Line 506 creates
KFold(n_splits=k_folds, ...)butk_foldscan be 1 whenlen(y_train) == 1, which is invalid forKFold. Guard against this:+ k_folds = max(2, k_folds) # KFold requires at least 2 splits kf = KFold(n_splits=k_folds, shuffle=True)And either remove the
num_outputsparameter from the signature or use it:- num_outputs = max(1, len(self.devices)) + if num_outputs is None or num_outputs < 1: + num_outputs = max(1, len(self.devices))
694-699: Verbose test block always calls classification evaluator regardless of task.When
task == "regression"(for Hellinger distance), lines 696-698 still callevaluate_classification_model, which computes accuracy metrics that are meaningless for regression. Branch on thetaskvariable.if verbose: test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False) - avg_loss_test, dict_results, _ = evaluate_classification_model( - model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose - ) + if task == "regression": + avg_loss_test, dict_results, _ = evaluate_regression_model( + model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose + ) + else: + avg_loss_test, dict_results, _ = evaluate_classification_model( + model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose + ) print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")src/mqt/predictor/ml/helper.py (2)
304-317:get_results_classeswill corrupt target labels when targets are shaped[N, 1].This function applies
argmax(targets, dim=1)unconditionally. Whentargetsis shaped[N, 1](as happens aftertargets.unsqueeze(1)at line 361),argmaxalong dim=1 always returns 0 regardless of the actual class index. This causes accuracy metrics to be computed incorrectly.Apply this diff to handle both one-hot and integer label encodings:
def get_results_classes(preds: torch.Tensor, targets: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Return predicted and target class indices. Arguments: preds: model predictions - targets: ground truth targets + targets: ground truth targets (one-hot encoded or integer class labels) Returns: pred_idx: predicted class indices targets_idx: target class indices """ pred_idx = torch.argmax(preds, dim=1) - targets_idx = torch.argmax(targets, dim=1) + if targets.dim() == 1: + targets_idx = targets.long() + elif targets.size(1) == 1: + targets_idx = targets.view(-1).long() + else: + targets_idx = torch.argmax(targets, dim=1) return pred_idx, targets_idx
323-331: Return type annotation is too narrow formetricsdict.The function is annotated as returning
dict[str, float], butmetrics["classification_report"]at line 389 is assigned a string fromclassification_report(). This type mismatch will cause static type checkers to flag issues.-) -> tuple[float, dict[str, float], tuple[np.ndarray, np.ndarray] | None]: +) -> tuple[float, dict[str, float | str], tuple[np.ndarray, np.ndarray] | None]:And at line 375:
- metrics = {"loss": float(avg_loss)} + metrics: dict[str, float | str] = {"loss": float(avg_loss)}
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (3)
src/mqt/predictor/ml/helper.py(4 hunks)src/mqt/predictor/ml/predictor.py(12 hunks)tests/device_selection/test_predictor_ml.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/mqt/predictor/ml/predictor.py (6)
src/mqt/predictor/ml/gnn.py (1)
GNN(191-295)src/mqt/predictor/hellinger/utils.py (1)
get_hellinger_model_path(135-144)src/mqt/predictor/ml/helper.py (9)
TrainingData(575-585)create_dag(192-301)create_feature_vector(164-189)evaluate_classification_model(323-402)evaluate_regression_model(405-469)get_openqasm3_gates(119-151)get_path_trained_model(48-50)get_path_trained_model_gnn(53-57)train_model(472-571)tests/device_selection/test_predictor_ml.py (2)
path_uncompiled_circuits(27-29)path_compiled_circuits(33-35)tests/hellinger_distance/test_estimated_hellinger_distance.py (1)
device(49-51)tests/compilation/test_reward.py (1)
device(38-40)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
113-113: Boolean-typed positional argument in function definition
(FBT001)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
570-570: Boolean-typed positional argument in function definition
(FBT001)
570-570: Boolean default positional argument in function definition
(FBT002)
800-800: Boolean-typed positional argument in function definition
(FBT001)
800-800: Boolean default positional argument in function definition
(FBT002)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (12)
tests/device_selection/test_predictor_ml.py (3)
38-43: Test parameterization covers both RF and GNN paths appropriately.The parameterization includes
gnn=False(RF),gnn=True, andgnn=Truewithverbose=True, providing good coverage for both training paths. The test IDs (rf,gnn,gnn_verbose) are clear and descriptive.
68-73: Artifact assertions correctly validate the expected output format per mode.The conditional checks appropriately verify:
- GNN mode produces
graph_dataset_*.pt- RF mode produces
training_data_*.npy,names_list_*.npy, andscores_list_*.npyThis aligns with the
generate_training_dataimplementation inpredictor.py.
112-115: LGTM!The parameterization correctly tests both RF and GNN paths for the
_get_prepared_training_datafailure case.src/mqt/predictor/ml/helper.py (6)
53-57: LGTM!The function correctly returns the path to the GNN checkpoint file using an f-string as previously suggested. The docstring accurately describes the return value.
119-151: LGTM!The OpenQASM 3.0 gate list includes the verification note and reference to the specification. The gates match the standard library documented at the referenced URL.
192-301: DAG featurization is well-structured with appropriate edge-case handling.The implementation correctly handles:
- Empty edge sets (lines 265-268)
- Empty
topo_nodes(lines 272-276)- Unknown gates with clear error messages (lines 247-249)
The feature layout is comprehensive: one-hot gate encoding, sin/cos parameters, arity, controls, num_params, critical path flag, and fan-out proportion.
405-469: LGTM!The regression evaluator correctly handles empty loaders, computes standard metrics (RMSE, MAE, R²), and guards against degenerate cases (constant targets, insufficient samples) by returning NaN.
472-572: LGTM!The training loop is well-structured with:
- Early stopping based on validation loss with patience and min_delta
- Best model state restoration
- Proper device handling
- Task-specific evaluation branching
574-585: LGTM!The
TrainingDatadataclass correctly documents support for both classical (numpy) and GNN (graph) models with appropriately widened type annotations.src/mqt/predictor/ml/predictor.py (3)
83-97: LGTM!The type aliases and
TrainGNNKwargsTypedDict are well-defined and provide clear type information for the GNN training pipeline.
326-341: LGTM!The GNN training sample construction correctly:
- Extracts node features and edges from the DAG
- Creates device score tensor from sorted device names
- Uses
num_nodes(notnum_gates) for the Data object
752-796: LGTM!The
_get_prepared_training_datamethod correctly handles both RF (numpy arrays) and GNN (torch Data objects) paths, with appropriate file loading and data preparation logic.
scores_list = scores  # list(scores.values())
target_label = max(scores, key=lambda k: scores[k])
🧹 Nitpick | 🔵 Trivial
Remove or clarify the commented assignment.
Line 421 assigns scores_list = scores with a commented-out alternative # list(scores.values()). This suggests uncertainty about the intended value. Since scores is already a dict and is used directly, the comment should be removed to avoid confusion.
- scores_list = scores # list(scores.values())
+ scores_list = scores
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 421 to 422, remove the
ambiguous commented assignment and clarify the intended variable: replace or
delete the line assigning scores_list = scores with a single clear statement
(either remove scores_list entirely if unused, or set scores_list =
list(scores.values()) if a list of values is required) and remove the commented
alternative so the code reflects the actual intended type and usage.
Actionable comments posted: 1
♻️ Duplicate comments (14)
tests/device_selection/test_predictor_ml.py (2)
68-73: GNN path assertions are incomplete.The GNN path only asserts
graph_dataset_expected_fidelity.ptexists, but according to the predictor code (lines 350-353),names_list_*.npyandscores_list_*.npyare saved regardless of thegnnflag. Add these assertions for the GNN branch to ensure complete validation.if gnn: assert (data_path / "graph_dataset_expected_fidelity.pt").exists() + assert (data_path / "names_list_expected_fidelity.npy").exists() + assert (data_path / "scores_list_expected_fidelity.npy").exists() else:
98-98: Useinoperator for cleaner suffix check.- if file.suffix == ".npy" or file.suffix == ".pt": + if file.suffix in (".npy", ".pt"):src/mqt/predictor/ml/predictor.py (12)
107-108: Consider keyword-only forgnnparameter.Making
gnnkeyword-only (after*) would prevent accidental positional usage and address the Ruff FBT001/FBT002 hints.
327-341: Device score ordering may diverge from the self.devices order. Line 329 uses sorted(scores.keys()) for device ordering, but self.devices is sorted by dev.description at initialization. If scores.keys() contains different keys or sorts differently than the device descriptions, the y tensor indices won't match the model's expected output ordering. Use the canonical self.devices ordering:
 if self.gnn:
     x, _y, edge_idx, n_nodes, target_label = training_sample
-    name_device = sorted(scores.keys())
-    value_device = [scores[i] for i in name_device]
+    # Use canonical device ordering from self.devices
+    value_device = [scores.get(dev.description, -1.0) for dev in self.devices]
     gnn_training_sample = Data(
421-422: Remove the commented alternative. The commented-out # list(scores.values()) adds confusion. Since scores is used directly as a dict, remove the comment.
426-429: The intermediate y tensor is computed but discarded. Line 428 computes y = torch.tensor(...), but it is immediately discarded in generate_training_data (line 328: x, _y, edge_idx, ...). Either remove this computation or document why it exists.
606-606: Dead code: list comprehension result not assigned. Line 606 computes [dev.description for dev in self.devices] but doesn't assign the result. This should store the device labels in json_dict["classes"] for use during inference:
- [dev.description for dev in self.devices]
+ json_dict["classes"] = [dev.description for dev in self.devices]
628-636: mlp_units is not stored in JSON metadata; inference will fail. Line 849 in predict_device_for_figure_of_merit reads json_dict["mlp_units"], but the JSON only contains "mlp" (a string). Store the parsed list:
 json_dict["num_outputs"] = (
     len(self.devices) if self.figure_of_merit != "hellinger_distance" and self.figure_of_merit != "estimated_hellinger_distance" else 1
 )
+ json_dict["mlp_units"] = mlp_units  # Store parsed list for inference
637-666: Reduce code duplication in GNN instantiation. The two GNN instantiation blocks differ only in output_dim. Consolidate them:
+ output_dim = 1 if self.figure_of_merit in ("hellinger_distance", "estimated_hellinger_distance") else len(self.devices)
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
- if self.figure_of_merit != "hellinger_distance" and self.figure_of_merit != "estimated_hellinger_distance":
-     model = GNN(
-         ...
-         output_dim=len(self.devices),
-         ...
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
- else:
-     model = GNN(
-         ...
-         output_dim=1,
-         ...
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
+ model = GNN(
+     in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+     num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+     hidden_dim=dict_best_hyper["hidden_dim"],
+     num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+     mlp_units=mlp_units,
+     output_dim=output_dim,
+     dropout_p=dict_best_hyper["dropout"],
+     bidirectional=dict_best_hyper["bidirectional"],
+     use_sag_pool=dict_best_hyper["sag_pool"],
+     sag_ratio=0.7,
+     conv_activation=torch.nn.functional.leaky_relu,
+     mlp_activation=torch.nn.functional.leaky_relu,
+ ).to(device_str)
677-691: Double train/val split and incorrect device type.
Lines 677-679: Data is split again (80/20) after _get_prepared_training_data already performed a 70/30 split. This results in effective train/val/test proportions of 56%/14%/30%. If intentional, document it; otherwise use training_data.X_test as validation.
Line 691: train_model expects device: str | None, but a torch.device object is passed. Pass str(device) or the string directly.
Dataclass expectsnum_nodesas the standard attribute. Usingnum_gatesmay cause issues with GNN operations.- feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates) + feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates)
862-871: Critical: GNN prediction will fail at runtime due to multiple issues.
Line 862:
gnn_model.classesis never defined on theGNNclass or set during training. This will raiseAttributeError. Load from JSON metadata instead (after fixing training to save it).Lines 864-868: Shape mismatch. The model returns tensor of shape
[1, num_devices]for a single Data object.len(outputs)returns 1 (batch dimension), notnum_devices, so this assertion will always fail for multi-device scenarios.Line 870:
outputs.tolist()with shape[1, N]becomes[[score1, score2, ...]]. Zipping this withclass_labelsiterates only once.Fix by loading classes from JSON and squeezing the output:
- class_labels = gnn_model.classes + class_labels = json_dict["classes"] # Must be saved during training with torch.no_grad(): outputs = gnn_model(feature_vector) + outputs = outputs.squeeze(0) # Remove batch dimension [1, num_devices] -> [num_devices] assert class_labels is not None if len(class_labels) != len(outputs): msg = "outputs and class_labels must be same length" raise ValueError(msg) pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
698-703: Verbose evaluation ignores task type.The verbose block unconditionally calls
evaluate_classification_model, but for"hellinger_distance"or"estimated_hellinger_distance", the task is"regression". Branch ontask:if verbose: test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False) - avg_loss_test, dict_results, _ = evaluate_classification_model( - model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose - ) + if task == "regression": + avg_loss_test, dict_results, _ = evaluate_regression_model( + model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose + ) + else: + avg_loss_test, dict_results, _ = evaluate_classification_model( + model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose + ) print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
503-506:num_outputsparameter is ignored;k_foldsmay be invalid for tiny datasets.
Line 503 overwrites
num_outputswithmax(1, len(self.devices)), making the parameter useless. Either remove it from the signature or honor the passed value.
KFold(n_splits=k_folds)withk_folds=1(possible iflen(training_data.y_train) == 1) will raise an error. Guard against this:- num_outputs = max(1, len(self.devices)) + num_outputs = num_outputs if num_outputs > 0 else max(1, len(self.devices)) # Split into k-folds + k_folds = max(2, k_folds) # KFold requires at least 2 splits kf = KFold(n_splits=k_folds, shuffle=True)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (2)
src/mqt/predictor/ml/predictor.py(13 hunks)tests/device_selection/test_predictor_ml.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/device_selection/test_predictor_ml.py (3)
src/mqt/predictor/ml/helper.py (1)
get_path_training_data(36-38)src/mqt/predictor/rl/helper.py (1)
get_path_training_data(97-99)src/mqt/predictor/ml/predictor.py (1)
predict_device_for_figure_of_merit(803-878)
src/mqt/predictor/ml/predictor.py (2)
src/mqt/predictor/ml/gnn.py (1)
GNN(191-295)src/mqt/predictor/ml/helper.py (12)
TrainingData(575-585)create_dag(192-301)create_feature_vector(164-189)evaluate_classification_model(323-402)evaluate_regression_model(405-469)get_openqasm3_gates(119-151)get_path_trained_model(48-50)get_path_trained_model_gnn(53-57)get_path_training_circuits(60-62)get_path_training_circuits_compiled(65-67)get_path_training_data(36-38)train_model(472-571)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
113-113: Boolean-typed positional argument in function definition
(FBT001)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
570-570: Boolean-typed positional argument in function definition
(FBT001)
570-570: Boolean default positional argument in function definition
(FBT002)
804-804: Boolean-typed positional argument in function definition
(FBT001)
804-804: Boolean default positional argument in function definition
(FBT002)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
🔇 Additional comments (6)
tests/device_selection/test_predictor_ml.py (2)
38-43: Test parametrization covers key GNN paths. The parametrization with (False, False), (True, False), (True, True) appropriately tests RF mode, GNN mode, and GNN with verbose output. Consider adding (False, True) to also test RF with verbose mode for completeness, though this is optional.
112-116: Test coverage for the GNN prepared training data path is good. The parametrization ensures both RF and GNN paths are tested for the _get_prepared_training_data error handling. The test correctly verifies that FileNotFoundError is raised when training data is missing for both model types.
src/mqt/predictor/ml/predictor.py (4)
92-98: Well-defined TypedDict for GNN kwargs. The TrainGNNKwargs TypedDict with total=False cleanly captures optional GNN training parameters. This enables type-safe forwarding via **gnn_kwargs.
144-149: Conditional training path logic is correct. The branching between train_random_forest_model() and train_gnn_model(**gnn_kwargs) properly routes based on the gnn flag, with appropriate logging for each path.
766-773: GNN training data loading logic is correct. The conditional loading of .pt vs .npy files based on the self.gnn flag properly handles both data formats. The weights_only=False for torch.load is necessary here since the file contains Data objects, not just tensors.
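A condensed sketch of that branching (the paths, names, and flag are assumptions; the actual method reads its own attributes):

import numpy as np
import torch

def load_training_set(path_pt: str, path_npy: str, use_gnn: bool):
    if use_gnn:
        # Graph datasets are pickled Data objects, so full unpickling is required here.
        return torch.load(path_pt, weights_only=False)
    return np.load(path_npy, allow_pickle=True)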
858-858: Use weights_only=True when loading torch checkpoints for security. PyTorch's official documentation confirms that torch.load() without weights_only=True permits arbitrary code execution from untrusted model files through pickle deserialization. For loading a state_dict (plain tensor state), weights_only=True is the recommended secure approach. PyTorch changed the default to weights_only=True starting in version 2.6 due to these security concerns.
- gnn_model.load_state_dict(torch.load(path))
+ gnn_model.load_state_dict(torch.load(path, weights_only=True))
# Figure out outputs and save path
if self.figure_of_merit == "hellinger_distance" or self.figure_of_merit == "estimated_hellinger_distance":
    if len(self.devices) != 1:
        msg = "A single device must be provided for Hellinger distance model training."
        raise ValueError(msg)
    num_outputs = 1
    save_mdl_path = str(get_hellinger_model_path(self.devices[0], gnn=True))
else:
    num_outputs = max(1, len(self.devices))
    save_mdl_path = str(get_path_trained_model_gnn(self.figure_of_merit))
🧹 Nitpick | 🔵 Trivial
Hellinger distance figure_of_merit check can be simplified.
The repeated or checks for "hellinger_distance" and "estimated_hellinger_distance" appear multiple times. Consider using a tuple membership test for clarity:
- if self.figure_of_merit == "hellinger_distance" or self.figure_of_merit == "estimated_hellinger_distance":
+ if self.figure_of_merit in ("hellinger_distance", "estimated_hellinger_distance"):
This pattern appears at lines 585, 600, 632-634, 637, and 734. A constant or helper could further reduce duplication.
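One possible shape for that deduplication, following the reviewer's proposed names (a sketch, not the merged implementation):

# Sketch only; constant and helper names follow the reviewer's suggestion.
HELLINGER_FIGURES = ("hellinger_distance", "estimated_hellinger_distance")

def is_hellinger(figure_of_merit: str) -> bool:
    return figure_of_merit in HELLINGER_FIGURES

# Usage at the flagged call sites:
# if is_hellinger(self.figure_of_merit): ...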
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around lines 584 to 593, the conditional
that checks for "hellinger_distance" or "estimated_hellinger_distance" should be
simplified and deduplicated: replace the repeated `== "hellinger_distance" or ==
"estimated_hellinger_distance"` checks with a membership test like `if
self.figure_of_merit in ("hellinger_distance", "estimated_hellinger_distance"):`
and add a small module-level constant or helper function (e.g.,
HELLINGER_FIGURES = ("hellinger_distance", "estimated_hellinger_distance") or
def is_hellinger(f): return f in HELLINGER_FIGURES) then use that
constant/helper to replace the same pattern at the other occurrences you noted
(lines ~585, 600, 632-634, 637, and 734) so the logic is identical but not
duplicated.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/hellinger_distance/test_estimated_hellinger_distance.py (1)
254-265: Windows timeout warning is incorrectly gated by the gnn flag. The warning should appear regardless of whether GNN is used. The condition sys.platform == "win32" and not gnn means the warning is only expected for RF on Windows, but the timeout limitation applies to both paths:
- if sys.platform == "win32" and not gnn:
+ if sys.platform == "win32":
If the GNN path intentionally doesn't use a timeout, add a comment explaining why.
♻️ Duplicate comments (15)
src/mqt/predictor/ml/predictor.py (12)
604-606: Dead code: the list comprehension result is discarded. Line 606 computes [dev.description for dev in self.devices] but does not assign the result. This appears to be an incomplete attempt to store class labels in json_dict. Either remove this line or assign it appropriately:
- [dev.description for dev in self.devices]
+ # Class labels are stored below in json_dict["class_labels"]
Or simply delete line 606, since line 665 already stores the class labels correctly.
633-662: Duplicate GNN instantiation code can be consolidated. The two branches differ only in output_dim. This was flagged before and marked as addressed, but the duplication remains:
- if self.figure_of_merit != "hellinger_distance":
-     model = GNN(
-         in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
-         ...
-         output_dim=len(self.devices),
-         ...
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
- else:
-     model = GNN(
-         ...
-         output_dim=1,
-         ...
-     ).to("cuda" if torch.cuda.is_available() else "cpu")
+ output_dim = 1 if self.figure_of_merit == "hellinger_distance" else len(self.devices)
+ device_str = "cuda" if torch.cuda.is_available() else "cpu"
+ model = GNN(
+     in_feats=int(len(get_openqasm3_gates()) + 1 + 6 + 3 + 1 + 1),
+     num_conv_wo_resnet=dict_best_hyper["num_conv_wo_resnet"],
+     hidden_dim=dict_best_hyper["hidden_dim"],
+     num_resnet_layers=dict_best_hyper["num_resnet_layers"],
+     mlp_units=mlp_units,
+     output_dim=output_dim,
+     dropout_p=dict_best_hyper["dropout"],
+     bidirectional=dict_best_hyper["bidirectional"],
+     use_sag_pool=dict_best_hyper["sag_pool"],
+     sag_ratio=0.7,
+     conv_activation=torch.nn.functional.leaky_relu,
+     mlp_activation=torch.nn.functional.leaky_relu,
+ ).to(device_str)
674-676: Double train/val split results in unexpected data proportions. _get_prepared_training_data() already performs a 70/30 train/test split. Lines 674-676 then split the training set again 80/20, resulting in effective proportions of ~56%/14%/30% (train/val/test). If intentional, document it; otherwise, consider using training_data.X_test as the validation set.
688-688: Type mismatch: train_model expects device: str | None but receives torch.device. Line 688 passes a torch.device object where a string is expected. While this works due to duck typing, it's technically incorrect. Pass a string instead:
- device=device,
+ device=str(device),
695-700: Verbose evaluation always uses classification metrics, even for regression tasks. When task == "regression" (Hellinger distance), the verbose block still calls evaluate_classification_model, producing incorrect metrics. Branch on task:
 if verbose:
     test_loader = DataLoader(training_data.X_test, batch_size=16, shuffle=False)
-    avg_loss_test, dict_results, _ = evaluate_classification_model(
-        model, test_loader, loss_fn=loss_fn, device=device, verbose=verbose
-    )
+    if task == "regression":
+        avg_loss_test, dict_results, _ = evaluate_regression_model(
+            model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+        )
+    else:
+        avg_loss_test, dict_results, _ = evaluate_classification_model(
+            model, test_loader, loss_fn=loss_fn, device=str(device), verbose=verbose
+        )
     print(f"Test loss: {avg_loss_test:.4f}, {dict_results}")
860-862: Use the standard PyTorch Geometric attribute num_nodes instead of num_gates. Data objects expect num_nodes as the standard attribute name. Using num_gates may cause issues with some GNN operations that rely on the standard attribute.
- feature_vector = Data(x=x, edge_index=edge_index, num_gates=number_of_gates).to(
+ feature_vector = Data(x=x, edge_index=edge_index, num_nodes=number_of_gates).to(
      "cuda" if torch.cuda.is_available() else "cpu"
  )
866-873: Output tensor shape mismatch will cause an assertion failure. The GNN model returns a tensor of shape [1, num_devices] for a single Data object. len(outputs) returns 1 (the batch dimension), not num_devices, so the assertion len(class_labels) != len(outputs) will always fail for multi-device cases. Squeeze the batch dimension before comparison:
  with torch.no_grad():
      outputs = gnn_model(feature_vector)
+     outputs = outputs.squeeze(0)  # Remove batch dimension: [1, num_devices] -> [num_devices]
  assert class_labels is not None
  if len(class_labels) != len(outputs):
      msg = "outputs and class_labels must be same length"
      raise ValueError(msg)
  pairs = sorted(zip(outputs.tolist(), class_labels, strict=False), reverse=True)
327-337: Device score ordering may diverge from the self.devices order. Line 329 uses sorted(scores.keys()) for device ordering. If this differs from the sorted order of self.devices (sorted by dev.description at lines 189-191), the y tensor indices won't match the model's expected output ordering. Consider using the canonical self.devices ordering:
 if self.gnn:
     x, _y, edge_idx, n_nodes, target_label = training_sample
-    name_device = sorted(scores.keys())
-    value_device = [scores[i] for i in name_device]
+    # Use canonical device ordering from self.devices
+    value_device = [scores.get(dev.description, -1.0) for dev in self.devices]
421-422: Remove the ambiguous commented code. The commented alternative # list(scores.values()) suggests uncertainty about the intended value. Since scores is used directly as a dict, remove the comment:
- scores_list = scores  # list(scores.values())
+ scores_list = scores
628-632: The mlp_units list is reconstructed from the "mlp" string rather than stored explicitly. Lines 841-842 in predict_device_for_figure_of_merit parse json_dict["mlp"] to reconstruct mlp_units, which works. However, json_dict["num_outputs"] is stored while the parsed mlp_units list itself is not. This is fine as long as the parsing logic matches exactly. Consider adding mlp_units explicitly for robustness:
  json_dict["num_outputs"] = len(self.devices) if self.figure_of_merit != "hellinger_distance" else 1
+ json_dict["mlp_units"] = mlp_units  # Store parsed list for clarity
503-503: The num_outputs parameter is immediately overwritten, rendering it unused. Line 503 overwrites num_outputs with max(1, len(self.devices)) regardless of the value passed as an argument. Either remove num_outputs from the function signature and docstring, or honor the argument when provided (e.g., use a sentinel value like None):
- num_outputs: int,
+ num_outputs: int | None = None,
  ...
- num_outputs = max(1, len(self.devices))
+ if num_outputs is None:
+     num_outputs = max(1, len(self.devices))
506-506: KFold(n_splits=1) is invalid and will raise an error for very small datasets. When len(training_data.y_train) == 1, k_folds becomes 1, and KFold(n_splits=1) raises a ValueError. Guard against this:
- k_folds = min(len(training_data.y_train), 5)
+ k_folds = max(2, min(len(training_data.y_train), 5))
Alternatively, skip the hyperparameter search entirely for datasets with fewer than 2 samples.
tests/hellinger_distance/test_estimated_hellinger_distance.py (3)
293-324: Cleanup logic as a test is fragile and order-dependent. test_remove_files acts as teardown but depends on test execution order. If tests run in parallel or a subset is selected via -k, cleanup may not run or may run prematurely. Consider using a module-scoped autouse fixture with yield:
@pytest.fixture(scope="module", autouse=True)
def cleanup_test_artifacts(source_path: Path, target_path: Path):
    """Cleanup test artifacts after all tests in module complete."""
    yield  # Let tests run
    # Cleanup logic here...
    if source_path.exists():
        for file in source_path.iterdir():
            if file.suffix == ".qasm":
                file.unlink()
        source_path.rmdir()
    # ... rest of cleanup
307-317: Duplicate data_path computation can be consolidated. Lines 307-311 and 313-317 both compute data_path and iterate over it. Consolidate into a single loop:
  data_path = get_path_training_data() / "training_data_aggregated"
  if data_path.exists():
      for file in data_path.iterdir():
-         if file.suffix == ".npy":
+         if file.suffix in (".npy", ".pt"):
              file.unlink()
-
- data_path = get_path_training_data() / "training_data_aggregated"
- if data_path.exists():
-     for file in data_path.iterdir():
-         if file.suffix == ".pt":
-             file.unlink()
216-220: Docstring is outdated and doesn't reflect the parametrized test behavior. The docstring mentions "the Hellinger distance model that was trained in the previous test", but this test now parametrizes over RF vs GNN and trains internally. Update to reflect current behavior:
- """Test the entire predictor toolchain with the Hellinger distance model that was trained in the previous test."""
+ """Test the full predictor toolchain for `estimated_hellinger_distance` for both RF and GNN backends."""
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (2)
src/mqt/predictor/ml/predictor.py (12 hunks)
tests/hellinger_distance/test_estimated_hellinger_distance.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/mqt/predictor/ml/predictor.py (2)
src/mqt/predictor/ml/gnn.py (1)
GNN(191-295)src/mqt/predictor/ml/helper.py (7)
TrainingData(575-585)create_dag(192-301)create_feature_vector(164-189)get_openqasm3_gates(119-151)get_path_trained_model(48-50)get_path_trained_model_gnn(53-57)get_path_training_data(36-38)
tests/hellinger_distance/test_estimated_hellinger_distance.py (1)
src/mqt/predictor/ml/helper.py (3)
TrainingData(575-585)create_dag(192-301)get_path_training_data(36-38)
🪛 Ruff (0.14.6)
src/mqt/predictor/ml/predictor.py
107-107: Boolean-typed positional argument in function definition
(FBT001)
107-107: Boolean default positional argument in function definition
(FBT002)
146-146: Logging statement uses f-string
(G004)
149-149: Logging statement uses f-string
(G004)
173-173: Boolean-typed positional argument in function definition
(FBT001)
173-173: Boolean default positional argument in function definition
(FBT002)
448-448: Boolean-typed positional argument in function definition
(FBT001)
448-448: Boolean default positional argument in function definition
(FBT002)
570-570: Boolean-typed positional argument in function definition
(FBT001)
570-570: Boolean default positional argument in function definition
(FBT002)
801-801: Boolean-typed positional argument in function definition
(FBT001)
801-801: Boolean default positional argument in function definition
(FBT002)
tests/hellinger_distance/test_estimated_hellinger_distance.py
158-158: Boolean-typed positional argument in function definition
(FBT001)
218-218: Boolean-typed positional argument in function definition
(FBT001)
🔇 Additional comments (1)
src/mqt/predictor/ml/predictor.py (1)
83-86: Type alias definitions look good. The type aliases GNNSample, FeatureSample, and TrainingSample provide clear documentation of the expected tuple structures for both training paths.
        conv_activation=torch.nn.functional.leaky_relu,
        mlp_activation=torch.nn.functional.leaky_relu,
    ).to("cuda" if torch.cuda.is_available() else "cpu")
    gnn_model.load_state_dict(torch.load(path))
Missing weights_only=True for torch.load poses a security risk.
torch.load(path) without weights_only=True can execute arbitrary code if the model file is malicious. Since you're only loading state dicts:
- gnn_model.load_state_dict(torch.load(path))
+ gnn_model.load_state_dict(torch.load(path, weights_only=True))
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- gnn_model.load_state_dict(torch.load(path))
+ gnn_model.load_state_dict(torch.load(path, weights_only=True))
🤖 Prompt for AI Agents
In src/mqt/predictor/ml/predictor.py around line 858, replace the insecure
torch.load(path) call with torch.load(path, weights_only=True) so only tensor
weights are loaded; if your runtime device matters also pass map_location (e.g.,
map_location=device) and then feed the resulting state dict into
gnn_model.load_state_dict; optionally wrap the load in a try/except to surface a
clear error message on failure.
Actionable comments posted: 2
♻️ Duplicate comments (5)
tests/hellinger_distance/test_estimated_hellinger_distance.py (5)
186-192: num_nodes should be an integer, not a tensor. PyTorch Geometric's Data class expects num_nodes to be a plain Python int. Wrapping it in torch.tensor([n_nodes]) creates a 1D tensor, which can cause unexpected behavior during batching or graph operations.
  gnn_training_sample = Data(
      x=x,
      y=torch.tensor(labels_list[i], dtype=torch.float32),
      edge_index=edge_idx,
-     num_nodes=torch.tensor([n_nodes]),
+     num_nodes=n_nodes,
  )
219-221: Docstring is outdated after parameterization.The docstring still references "the Hellinger distance model that was trained in the previous test," but this test now trains its own model and is parameterized for both RF and GNN backends. Update to reflect the current behavior.
def test_train_and_qcompile_with_hellinger_model( source_path: Path, target_path: Path, device: Target, gnn: bool ) -> None: - """Test the entire predictor toolchain with the Hellinger distance model that was trained in the previous test.""" + """Test the entire predictor toolchain for estimated_hellinger_distance with both RF and GNN backends."""
307-317: Consolidate duplicatedata_pathcomputation.
data_pathis computed twice with the same value. Merge the cleanup loops to avoid redundant computation.data_path = get_path_training_data() / "training_data_aggregated" if data_path.exists(): for file in data_path.iterdir(): - if file.suffix == ".npy": + if file.suffix in (".npy", ".pt"): file.unlink() - - data_path = get_path_training_data() / "training_data_aggregated" - if data_path.exists(): - for file in data_path.iterdir(): - if file.suffix == ".pt": - file.unlink()
319-323: Use tuple membership instead of chainedorcomparisons.For readability and consistency, use
inwith a tuple.- if file.suffix == ".joblib" or file.suffix == ".pth" or file.suffix == ".json": + if file.suffix in (".joblib", ".pth", ".json"):
293-324: Consider converting cleanup test to a fixture.
test_remove_filesacts as teardown for artifacts created by other tests. This couples correctness to test ordering and can break under parallel execution or when running a subset of tests with-k. Consider moving this logic to a module-scoped autouse fixture withyieldto guarantee cleanup regardless of test selection or execution order.Example fixture approach:
@pytest.fixture(scope="module", autouse=True) def cleanup_test_artifacts(source_path: Path, target_path: Path): """Clean up test artifacts after all tests in the module complete.""" yield # Cleanup logic here...
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (1)
tests/hellinger_distance/test_estimated_hellinger_distance.py(5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/hellinger_distance/test_estimated_hellinger_distance.py (3)
src/mqt/predictor/hellinger/utils.py (2)
calc_device_specific_features(37-132)hellinger_distance(29-34)src/mqt/predictor/ml/helper.py (3)
TrainingData(575-585)create_dag(192-301)get_path_training_data(36-38)src/mqt/predictor/rl/helper.py (1)
get_path_training_data(97-99)
🪛 Ruff (0.14.6)
tests/hellinger_distance/test_estimated_hellinger_distance.py
158-158: Boolean-typed positional argument in function definition
(FBT001)
218-218: Boolean-typed positional argument in function definition
(FBT001)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
🔇 Additional comments (2)
tests/hellinger_distance/test_estimated_hellinger_distance.py (2)
21-31: LGTM!The new imports for
torch,torch_geometric.data(Batch, Data), andcreate_dagare appropriate for the GNN testing functionality.
254-265: Clarify the Windows + GNN timeout behavior.The condition
sys.platform == "win32" and not gnnmeans GNN tests on Windows skip the timeout warning path. If this is intentional (e.g., GNN compilation doesn't use the same timeout mechanism), a brief comment would help future maintainers understand this distinction.
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: ASSERTIVE
Plan: Pro
📒 Files selected for processing (2)
tests/device_selection/test_predictor_ml.py (4 hunks)
tests/hellinger_distance/test_estimated_hellinger_distance.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/device_selection/test_predictor_ml.py (2)
src/mqt/predictor/ml/helper.py (1)
get_path_training_data(36-38)src/mqt/predictor/ml/predictor.py (2)
predict_device_for_figure_of_merit(800-880)Predictor(166-797)
🪛 Ruff (0.14.6)
tests/device_selection/test_predictor_ml.py
42-42: Boolean-typed positional argument in function definition
(FBT001)
42-42: Boolean-typed positional argument in function definition
(FBT001)
115-115: Boolean-typed positional argument in function definition
(FBT001)
tests/hellinger_distance/test_estimated_hellinger_distance.py
158-158: Boolean-typed positional argument in function definition
(FBT001)
219-219: Boolean-typed positional argument in function definition
(FBT001)
⏰ Context from checks skipped due to timeout of 900000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: 🐍 Test (macos-14) / 🐍 macos-14
- GitHub Check: 🐍 Test (windows-2022) / 🐍 windows-2022
- GitHub Check: 🐍 Test (ubuntu-24.04) / 🐍 ubuntu-24.04
🔇 Additional comments (6)
tests/device_selection/test_predictor_ml.py (3)
38-43: Parameterization covers key scenarios appropriately.The test matrix exercises RF-only, GNN without verbose, and GNN with verbose modes. This provides good coverage for the new GNN path integration.
68-75: Artifact assertions correctly differentiate GNN and RF paths.The conditional logic now properly validates:
- GNN path:
graph_dataset_expected_fidelity.ptplus the shared.npyfiles- RF path:
training_data_expected_fidelity.npyplus the shared.npyfilesThis addresses the earlier feedback about asserting
names_listandscores_listexistence for the GNN path.
114-117: Parameterized test correctly exercises both predictor backends.The GNN flag is properly passed to
Predictor, ensuring both RF and GNN backends are tested for the_get_prepared_training_dataerror path.tests/hellinger_distance/test_estimated_hellinger_distance.py (3)
166-193: GNN training data construction is correct.The branching logic properly:
- Uses
calc_device_specific_featuresfor RF andcreate_dagfor GNN- Constructs appropriate
TrainingDataobjects for each path- Correctly passes
n_nodesas an integer toData(addressing earlier feedback)
196-214: Model training and evaluation paths are properly separated.The conditional training logic correctly dispatches to
train_gnn_modelvstrain_random_forest_model, and the evaluation uses appropriate methods for each model type. The tolerance comment at line 213 explains the rationale foratol=2e-1.
255-266: Clarify whether Windows timeout warning should apply to GNN path.The condition
sys.platform == "win32" and not gnnmeans the timeout warning test is skipped whengnn=Trueon Windows. Is this intentional because the GNN path handles timeouts differently, or should both paths warn on Windows?If GNN compilation also uses timeouts that aren't supported on Windows, consider:
- if sys.platform == "win32" and not gnn: + if sys.platform == "win32":Otherwise, add a comment explaining why GNN is exempt from this warning.
@pytest.mark.parametrize("gnn", [False, True], ids=["rf", "gnn"])
def test_train_model_and_predict(device: Target, gnn: bool) -> None:
    """Test the training of the random forest regressor. The trained model is saved and used in the following tests."""
🧹 Nitpick | 🔵 Trivial
Docstring no longer accurately reflects parameterized test behavior.
The docstring still says "Test the training of the random forest regressor" but the test is now parameterized to cover both RF and GNN backends.
@pytest.mark.parametrize("gnn", [False, True], ids=["rf", "gnn"])
def test_train_model_and_predict(device: Target, gnn: bool) -> None:
- """Test the training of the random forest regressor. The trained model is saved and used in the following tests."""
+ """Test model training and prediction for both RF and GNN backends."""🧰 Tools
🪛 Ruff (0.14.6)
158-158: Boolean-typed positional argument in function definition
(FBT001)
🤖 Prompt for AI Agents
In tests/hellinger_distance/test_estimated_hellinger_distance.py around lines
157 to 159, the test docstring incorrectly states it only tests training of the
random forest regressor while the test is parameterized to run for both RF and
GNN; update the docstring to accurately describe that the test trains and saves
a model for both backends (random forest and GNN) and that the saved model is
used in subsequent tests, and briefly document the gnn parameter meaning (False
=> RF, True => GNN).
def test_remove_files(source_path: Path, target_path: Path) -> None:
    """Remove files created during testing."""
    if source_path.exists():
        for file in source_path.iterdir():
            if file.suffix == ".qasm":
                file.unlink()
        source_path.rmdir()

    if target_path.exists():
        for file in target_path.iterdir():
            if file.suffix == ".qasm":
                file.unlink()
        target_path.rmdir()

    data_path = get_path_training_data() / "training_data_aggregated"
    if data_path.exists():
        for file in data_path.iterdir():
            if file.suffix in (".npy", ".pt"):
                file.unlink()

    model_path = get_path_training_data() / "trained_model"
    if model_path.exists():
        for file in model_path.iterdir():
            if file.suffix in (".joblib", ".pth", ".json"):
                file.unlink()
🧹 Nitpick | 🔵 Trivial
Consider converting cleanup to a fixture for test isolation.
test_remove_files acts as teardown but depends on test execution order. If tests are run in isolation or in parallel, cleanup may not occur. This was flagged in a previous review.
Convert to a session-scoped autouse fixture:
@pytest.fixture(scope="module", autouse=True)
def cleanup_test_artifacts(source_path: Path, target_path: Path):
"""Clean up test artifacts after all tests in the module complete."""
yield # Let tests run
# Cleanup logic here...This ensures cleanup runs regardless of test selection or ordering.
🤖 Prompt for AI Agents
In tests/hellinger_distance/test_estimated_hellinger_distance.py around lines
294 to 318, replace the standalone test_remove_files teardown function with a
pytest fixture that runs automatically for the module (or session) and performs
cleanup after tests by yielding control and then executing the same removal
logic; make the fixture use scope="module" (or "session") and autouse=True,
accept source_path and target_path as parameters, ensure the post-yield cleanup
is idempotent (check existence before unlink/rmdir or use safe removal) and
covers data_path and model_path cleanup so cleanup runs regardless of test
ordering or isolation.
Description
This PR introduces a Graph Neural Network (GNN) as an alternative to the Random Forest model for predicting the best device to run a quantum circuit.
To support this, the preprocessing pipeline was redesigned: instead of manually extracting features from the circuit, the model now directly takes as input the Directed Acyclic Graph (DAG) representation of the quantum circuit.
🚀 Major Changes
Graph Neural Network Integration
🎯 Motivation
🔧 Fixes and Enhancements
📦 Dependency Updates
optuna>=4.5.0
torch-geometric>=2.6.1
Checklist: