diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..fe4d85d0 Binary files /dev/null and b/.DS_Store differ diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..ad38ef22 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,18 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + + + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/examples/.DS_Store b/examples/.DS_Store new file mode 100644 index 00000000..5a7a6938 Binary files /dev/null and b/examples/.DS_Store differ diff --git a/examples/automatic_temporal_classifier.py b/examples/automatic_temporal_classifier.py new file mode 100644 index 00000000..1a5de404 --- /dev/null +++ b/examples/automatic_temporal_classifier.py @@ -0,0 +1,77 @@ +import time + +import pyreason as pr +import torch +import torch.nn as nn +import networkx as nx +import numpy as np +import random +from datetime import timedelta + +seed_value = 65 # Good Gap Gap +# seed_value = 47 # Good Gap Good +# seed_value = 43 # Good Good Good +random.seed(seed_value) +np.random.seed(seed_value) +torch.manual_seed(seed_value) + + +def input_fn(): + return torch.rand(1, 3) # Dummy input function for the model + + +weld_model = nn.Linear(3, 2) +class_names = ["good", "gap"] + +# Define integration options: +# Only consider probabilities above 0.5, adjust lower bound for high confidence, and use a snap value. +interface_options = pr.ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=False, + snap_value=1.0 +) + +# Wrap the model using LogicIntegratedClassifier. +weld_quality_checker = pr.TemporalLogicIntegratedClassifier( + weld_model, + class_names, + identifier="weld_object", + interface_options=interface_options, + poll_interval=timedelta(seconds=0.5), + # poll_interval=1, + poll_condition="gap", + input_fn=input_fn, +) + +pr.add_rule(pr.Rule("repairing(weld_object) <-1 gap(weld_object)", "repair attempted rule")) +pr.add_rule(pr.Rule("defective(weld_object) <-1 gap(weld_object), repairing(weld_object)", "defective rule")) + +max_iters = 5 +for weld_iter in range(max_iters): + # Time step 1: Initial inspection shows the weld is good. + features = torch.rand(1, 3) # Values chosen to indicate a good weld. + t = pr.get_time() + logits, probs, classifier_facts = weld_quality_checker(features, t1=t, t2=t) + # print(f"=== Weld Inspection for Part: {weld_iter} ===") + # print("Logits:", logits) + # print("Probabilities:", probs) + for fact in classifier_facts: + pr.add_fact(fact) + + # Reasoning + pr.settings.atom_trace = True + pr.settings.verbose = False + again = False if weld_iter == 0 else True + interpretation = pr.reason(timesteps=1, again=again, restart=False) + trace = pr.get_rule_trace(interpretation) + print(f"\n=== Reasoning Rule Trace for Weld Part: {weld_iter} ===") + print(trace[0], "\n\n") + + time.sleep(5) + + # Check if part is defective + # if pr.get_logic_program().interp.query(pr.Query("defective(weld_object)")): + if interpretation.query(pr.Query("defective(weld_object)")): + print("Defective weld detected! \n Replacing the part.\n\n") + # break diff --git a/examples/blackjack.py b/examples/blackjack.py new file mode 100644 index 00000000..aec35d71 --- /dev/null +++ b/examples/blackjack.py @@ -0,0 +1,217 @@ +from datetime import timedelta +from pathlib import Path +import random +import shutil +import sys +import os +from time import sleep +import networkx as nx +import numba +import cv2 +import torch +from ultralytics import YOLO +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +from pyreason.scripts.learning.classification.yolo_classifier import YoloLogicIntegratedTemporalClassifier +from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions +from pyreason.scripts.rules.rule import Rule +from pyreason.pyreason import _Settings as Settings, reason, reset_settings, get_rule_trace, add_fact, add_rule, load_graph, save_rule_trace, get_time, Query, add_annotation_function, get_logic_program + + +MODEL = YOLO('/Users/coltonpayne/dyuman-pyreason/pyreason/pyreason/train56/weights/best.pt') +TRAINING_IMAGES_DIR = "/Users/coltonpayne/dyuman-pyreason/pyreason/examples/images/cards" + +CARD_NUMBERS = ["A", "2", "3", "4", "5", "6", "7", "8", "9", "10", + "J", "Q", "K"] + +CARD_SUITS = ["h", "d", "c", "s"] + +CARD_NAMES = [ + f"{number}{suit}" for number in CARD_NUMBERS for suit in CARD_SUITS] + +MAX_POINTS = 42 + +CARD_VALUES = { + "A": 3, + "2": 6, + "3": 6, + "4": 6, + "5": 6, + "6": 6, + "7": 6, + "8": 6, + "9": 6, + "10": 6, + "J": 9, + "Q": 9, + "K": 9 +} + +# Move all images back to the training directory after a run +def reset_used_images(): + training_dir = Path(TRAINING_IMAGES_DIR) + used_dir = training_dir / "used" + + if not used_dir.exists(): + print("No 'used' folder found. Nothing to reset.") + return + + for file in used_dir.iterdir(): + if file.is_file(): + shutil.move(str(file), str(training_dir / file.name)) + +@numba.njit +def init_hand_annotation_fn(annotations, weights): + """ + Given all the cards in a players hand, return a decimal representation of the hand. + Each point value of a card is a digit in the decimal representation. + """ + row = annotations[0] + digits = 0 + num_digits = 0 + + for i in range(len(row)): + card_value = int(row[i].l * 10) + digits = digits * 10 + card_value + num_digits += 1 + + bound_value = digits / (10 ** num_digits) + print("Points in player hand: ", bound_value) + return bound_value, 1 + + +@numba.njit +def hand_percent_to_losing_annotation_fn(annotations, weights): + """ + Calculate the odds of a player losing based on their hand and the remaining deck, + Given a decimal representation of the players hand and knowlege of the cards in the full deck + """ + fractional = annotations[0][0].l + # First, we need to see how many cards are in the player's hand + num_player_cards = 52 + initial_num_player_cards = 0 + scale = 1 + for d in range(1, 52 + 1): + scale *= 10 + scaled = fractional * scale + int_check = abs(scaled - int(scaled)) + if int_check < 1e-8: # close enough to an int + initial_num_player_cards = d + break + num_player_cards = initial_num_player_cards + fractional *= 10 + player_card_array = [] + player_point_total = 0 + # Now, we need to get the total number of points in the players hand. + # We also make an array of the point total the player has so we can remove equivalent cards from the game deck + for val in range(num_player_cards): + digit = int(fractional) + player_point_total += digit + player_card_array.append(digit) + fractional -= int(fractional) + # Add back a small floating point amount to avoid floating point precision issues + fractional += 1e-8 + if fractional == 0.0: + break # no more digits left + fractional *= 10 + + print("Player Point Total: ", player_point_total) + total_bust_cards = 0 + row = annotations[1] + for i in range(len(row)): + # This gets the first item in this numpy object + card_value = row[i].l * 10 + if card_value in player_card_array: + player_card_array.remove(card_value) + continue + #print("Card Value: ", card_value) + if card_value + player_point_total > MAX_POINTS: + total_bust_cards += 1 + bust_odds = total_bust_cards / (len(row) - num_player_cards) + print("Odds of losing on next card draw: ", bust_odds) + if bust_odds >= 1: + print("Every remaining card will put you over the point total. The game is over!") + print("Player Final Score: ", player_point_total) + bust_odds = 1 + + return bust_odds, 1 + + +add_annotation_function(init_hand_annotation_fn) +add_annotation_function(hand_percent_to_losing_annotation_fn) + +# Use these helper functions to make facts and rules for all the cards in the deck. +def add_deck_holds_fact(card_name): + card_value = CARD_VALUES[card_name[:-1]] + lower_bound = card_value / 10 + add_fact(Fact(f"deck_holds({card_name}, full_deck): [{lower_bound}, {1}]", "deck_holds_fact")) + +def add_player_holds_rule(card_name): + card_value = CARD_VALUES[card_name[:-1]] + lower_bound = card_value / 10 + add_rule(Rule(f"player_holds_(_{card_name}): [{lower_bound}, 1] <-0 _{card_name}(card_drawn_obj)", f"player_holds_{card_name}_rule")) + + +# This is the format of fact that should be returned from the YOLO Classifier. Each time interval, should return a new fact like this. +# We start the game with the player holding one card. +print("Initializing game...") + +# Initialize the deck and player holds rules for each card in the deck. +for card in CARD_NAMES: + add_deck_holds_fact(card) + add_player_holds_rule(card) + +# add_fact(Fact("_2c(card_drawn_obj)", "_2c_drawn_fact")) +add_rule(Rule("hand_as_point_vals(hand) : init_hand_annotation_fn <-0 player_holds_(card):[0.1,1]", "hand_as_point_vals_rule")) +add_rule(Rule("odds_of_losing(hand) : hand_percent_to_losing_annotation_fn <-0 hand_as_point_vals(hand):[0,1], deck_holds(card, full_deck):[0.1,1]", "odds_of_losing_rule")) + +settings = Settings +settings.atom_trace = True +settings.verbose = False + +# Input function for temporal classifier +# For yolo models, the input function should return an image. +def input_function(): + random.seed() + available_images = [p for p in Path(TRAINING_IMAGES_DIR).glob("*") if p.is_file()] + if not available_images: + print("No images left.") + return None + image_path = random.choice(available_images) + print("Path:", image_path) + image = cv2.imread(str(image_path)) + # Move image to a "used" folder so it doesn't draw the same card again + shutil.move(str(image_path), str(Path(TRAINING_IMAGES_DIR) / "used" / image_path.name)) + return image + +interface_options = ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=False, + snap_value=1.0 +) + +card_drawn_object = YoloLogicIntegratedTemporalClassifier( + MODEL, + class_names=CARD_NAMES, + identifier="card_drawn_obj", + interface_options=interface_options, + poll_interval=timedelta(seconds=5), + input_fn=input_function, +) + +interpretation = reason() +logic_program = get_logic_program() +interp = logic_program.interp +for i in range(200): + print("Quering game end condition...") + result = interp.query(Query("odds_of_losing(hand)")) + if result: + print("Player can not draw any more cards without going over the point total. Ending game.") + break + sleep(1) + +# Save the rule trace and move all images from the 'used' directory. +save_rule_trace(interpretation) +reset_used_images() \ No newline at end of file diff --git a/examples/images/cards/ace_clubs.png b/examples/images/cards/ace_clubs.png new file mode 100644 index 00000000..f2c43c81 Binary files /dev/null and b/examples/images/cards/ace_clubs.png differ diff --git a/examples/images/cards/ace_diamonds.jpg b/examples/images/cards/ace_diamonds.jpg new file mode 100644 index 00000000..aa9ea86d Binary files /dev/null and b/examples/images/cards/ace_diamonds.jpg differ diff --git a/examples/images/cards/ace_hearts.jpeg b/examples/images/cards/ace_hearts.jpeg new file mode 100644 index 00000000..58e7c8ea Binary files /dev/null and b/examples/images/cards/ace_hearts.jpeg differ diff --git a/examples/images/cards/ace_spade.jpeg b/examples/images/cards/ace_spade.jpeg new file mode 100644 index 00000000..6a57e874 Binary files /dev/null and b/examples/images/cards/ace_spade.jpeg differ diff --git a/examples/images/cards/eight_clubs.jpeg b/examples/images/cards/eight_clubs.jpeg new file mode 100644 index 00000000..abd28971 Binary files /dev/null and b/examples/images/cards/eight_clubs.jpeg differ diff --git a/examples/images/cards/eight_diamonds.jpeg b/examples/images/cards/eight_diamonds.jpeg new file mode 100644 index 00000000..2e954f64 Binary files /dev/null and b/examples/images/cards/eight_diamonds.jpeg differ diff --git a/examples/images/cards/eight_hearts.jpeg b/examples/images/cards/eight_hearts.jpeg new file mode 100644 index 00000000..a97ed877 Binary files /dev/null and b/examples/images/cards/eight_hearts.jpeg differ diff --git a/examples/images/cards/eight_spades.jpeg b/examples/images/cards/eight_spades.jpeg new file mode 100644 index 00000000..1902ba9b Binary files /dev/null and b/examples/images/cards/eight_spades.jpeg differ diff --git a/examples/images/cards/five_clubs.jpg b/examples/images/cards/five_clubs.jpg new file mode 100644 index 00000000..431f1a89 Binary files /dev/null and b/examples/images/cards/five_clubs.jpg differ diff --git a/examples/images/cards/five_diamonds.jpeg b/examples/images/cards/five_diamonds.jpeg new file mode 100644 index 00000000..27491cd0 Binary files /dev/null and b/examples/images/cards/five_diamonds.jpeg differ diff --git a/examples/images/cards/five_hearts.jpeg b/examples/images/cards/five_hearts.jpeg new file mode 100644 index 00000000..8f7386b9 Binary files /dev/null and b/examples/images/cards/five_hearts.jpeg differ diff --git a/examples/images/cards/five_spades.jpeg b/examples/images/cards/five_spades.jpeg new file mode 100644 index 00000000..03aecbef Binary files /dev/null and b/examples/images/cards/five_spades.jpeg differ diff --git a/examples/images/cards/four_clubs.jpeg b/examples/images/cards/four_clubs.jpeg new file mode 100644 index 00000000..d4a06f58 Binary files /dev/null and b/examples/images/cards/four_clubs.jpeg differ diff --git a/examples/images/cards/four_diamonds.jpeg b/examples/images/cards/four_diamonds.jpeg new file mode 100644 index 00000000..2be4f935 Binary files /dev/null and b/examples/images/cards/four_diamonds.jpeg differ diff --git a/examples/images/cards/four_hearts.jpeg b/examples/images/cards/four_hearts.jpeg new file mode 100644 index 00000000..5cd68bac Binary files /dev/null and b/examples/images/cards/four_hearts.jpeg differ diff --git a/examples/images/cards/four_spades.jpeg b/examples/images/cards/four_spades.jpeg new file mode 100644 index 00000000..71cc8226 Binary files /dev/null and b/examples/images/cards/four_spades.jpeg differ diff --git a/examples/images/cards/jack_clubs.jpg b/examples/images/cards/jack_clubs.jpg new file mode 100644 index 00000000..46ac5adb Binary files /dev/null and b/examples/images/cards/jack_clubs.jpg differ diff --git a/examples/images/cards/jack_diamond.jpeg b/examples/images/cards/jack_diamond.jpeg new file mode 100644 index 00000000..aced1ce4 Binary files /dev/null and b/examples/images/cards/jack_diamond.jpeg differ diff --git a/examples/images/cards/jack_hearts.jpeg b/examples/images/cards/jack_hearts.jpeg new file mode 100644 index 00000000..63e170cc Binary files /dev/null and b/examples/images/cards/jack_hearts.jpeg differ diff --git a/examples/images/cards/jack_spades.jpg b/examples/images/cards/jack_spades.jpg new file mode 100644 index 00000000..a4ac7ccb Binary files /dev/null and b/examples/images/cards/jack_spades.jpg differ diff --git a/examples/images/cards/king_clubs.jpeg b/examples/images/cards/king_clubs.jpeg new file mode 100644 index 00000000..2728081f Binary files /dev/null and b/examples/images/cards/king_clubs.jpeg differ diff --git a/examples/images/cards/king_diamonds.jpeg b/examples/images/cards/king_diamonds.jpeg new file mode 100644 index 00000000..6cef771e Binary files /dev/null and b/examples/images/cards/king_diamonds.jpeg differ diff --git a/examples/images/cards/king_hearts.jpeg b/examples/images/cards/king_hearts.jpeg new file mode 100644 index 00000000..0fc5fc30 Binary files /dev/null and b/examples/images/cards/king_hearts.jpeg differ diff --git a/examples/images/cards/king_spades.jpeg b/examples/images/cards/king_spades.jpeg new file mode 100644 index 00000000..2c4afae6 Binary files /dev/null and b/examples/images/cards/king_spades.jpeg differ diff --git a/examples/images/cards/nine_clubs.jpeg b/examples/images/cards/nine_clubs.jpeg new file mode 100644 index 00000000..9dc9054c Binary files /dev/null and b/examples/images/cards/nine_clubs.jpeg differ diff --git a/examples/images/cards/nine_diamonds.jpeg b/examples/images/cards/nine_diamonds.jpeg new file mode 100644 index 00000000..66ae9655 Binary files /dev/null and b/examples/images/cards/nine_diamonds.jpeg differ diff --git a/examples/images/cards/nine_hearts.jpeg b/examples/images/cards/nine_hearts.jpeg new file mode 100644 index 00000000..c4a51186 Binary files /dev/null and b/examples/images/cards/nine_hearts.jpeg differ diff --git a/examples/images/cards/nine_spades.jpeg b/examples/images/cards/nine_spades.jpeg new file mode 100644 index 00000000..382765eb Binary files /dev/null and b/examples/images/cards/nine_spades.jpeg differ diff --git a/examples/images/cards/queen_clubs.jpeg b/examples/images/cards/queen_clubs.jpeg new file mode 100644 index 00000000..dbc2e6ce Binary files /dev/null and b/examples/images/cards/queen_clubs.jpeg differ diff --git a/examples/images/cards/queen_diamonds.jpeg b/examples/images/cards/queen_diamonds.jpeg new file mode 100644 index 00000000..76ac8881 Binary files /dev/null and b/examples/images/cards/queen_diamonds.jpeg differ diff --git a/examples/images/cards/queen_hearts.jpeg b/examples/images/cards/queen_hearts.jpeg new file mode 100644 index 00000000..d46b98f0 Binary files /dev/null and b/examples/images/cards/queen_hearts.jpeg differ diff --git a/examples/images/cards/queen_spades.jpeg b/examples/images/cards/queen_spades.jpeg new file mode 100644 index 00000000..cfbda2bb Binary files /dev/null and b/examples/images/cards/queen_spades.jpeg differ diff --git a/examples/images/cards/seven_clubs.jpeg b/examples/images/cards/seven_clubs.jpeg new file mode 100644 index 00000000..1a2453d9 Binary files /dev/null and b/examples/images/cards/seven_clubs.jpeg differ diff --git a/examples/images/cards/seven_diamonds.jpeg b/examples/images/cards/seven_diamonds.jpeg new file mode 100644 index 00000000..9e2f8c81 Binary files /dev/null and b/examples/images/cards/seven_diamonds.jpeg differ diff --git a/examples/images/cards/seven_hearts.jpeg b/examples/images/cards/seven_hearts.jpeg new file mode 100644 index 00000000..c6af72ed Binary files /dev/null and b/examples/images/cards/seven_hearts.jpeg differ diff --git a/examples/images/cards/seven_spades.jpeg b/examples/images/cards/seven_spades.jpeg new file mode 100644 index 00000000..632189e6 Binary files /dev/null and b/examples/images/cards/seven_spades.jpeg differ diff --git a/examples/images/cards/six_clubs.jpg b/examples/images/cards/six_clubs.jpg new file mode 100644 index 00000000..65d667c9 Binary files /dev/null and b/examples/images/cards/six_clubs.jpg differ diff --git a/examples/images/cards/six_diamonds.jpeg b/examples/images/cards/six_diamonds.jpeg new file mode 100644 index 00000000..37930ff2 Binary files /dev/null and b/examples/images/cards/six_diamonds.jpeg differ diff --git a/examples/images/cards/six_hearts.jpeg b/examples/images/cards/six_hearts.jpeg new file mode 100644 index 00000000..34549988 Binary files /dev/null and b/examples/images/cards/six_hearts.jpeg differ diff --git a/examples/images/cards/six_spades.jpg b/examples/images/cards/six_spades.jpg new file mode 100644 index 00000000..3b7ecdb9 Binary files /dev/null and b/examples/images/cards/six_spades.jpg differ diff --git a/examples/images/cards/ten_clubs.png b/examples/images/cards/ten_clubs.png new file mode 100644 index 00000000..62972648 Binary files /dev/null and b/examples/images/cards/ten_clubs.png differ diff --git a/examples/images/cards/ten_diamonds.jpeg b/examples/images/cards/ten_diamonds.jpeg new file mode 100644 index 00000000..c4e19a02 Binary files /dev/null and b/examples/images/cards/ten_diamonds.jpeg differ diff --git a/examples/images/cards/ten_hearts.jpeg b/examples/images/cards/ten_hearts.jpeg new file mode 100644 index 00000000..3c537b46 Binary files /dev/null and b/examples/images/cards/ten_hearts.jpeg differ diff --git a/examples/images/cards/ten_spades.jpeg b/examples/images/cards/ten_spades.jpeg new file mode 100644 index 00000000..cc76ca6b Binary files /dev/null and b/examples/images/cards/ten_spades.jpeg differ diff --git a/examples/images/cards/three_clubs.jpeg b/examples/images/cards/three_clubs.jpeg new file mode 100644 index 00000000..7373fa3a Binary files /dev/null and b/examples/images/cards/three_clubs.jpeg differ diff --git a/examples/images/cards/three_diamonds.jpeg b/examples/images/cards/three_diamonds.jpeg new file mode 100644 index 00000000..0cd3e25a Binary files /dev/null and b/examples/images/cards/three_diamonds.jpeg differ diff --git a/examples/images/cards/three_hearts.jpeg b/examples/images/cards/three_hearts.jpeg new file mode 100644 index 00000000..536f0919 Binary files /dev/null and b/examples/images/cards/three_hearts.jpeg differ diff --git a/examples/images/cards/three_spades.jpeg b/examples/images/cards/three_spades.jpeg new file mode 100644 index 00000000..0ada2bfd Binary files /dev/null and b/examples/images/cards/three_spades.jpeg differ diff --git a/examples/images/cards/two_diamonds.jpg b/examples/images/cards/two_diamonds.jpg new file mode 100644 index 00000000..aae885bb Binary files /dev/null and b/examples/images/cards/two_diamonds.jpg differ diff --git a/examples/images/cards/two_hearts.jpeg b/examples/images/cards/two_hearts.jpeg new file mode 100644 index 00000000..60287ba7 Binary files /dev/null and b/examples/images/cards/two_hearts.jpeg differ diff --git a/examples/images/cards/two_spades.jpeg b/examples/images/cards/two_spades.jpeg new file mode 100644 index 00000000..29547dad Binary files /dev/null and b/examples/images/cards/two_spades.jpeg differ diff --git a/examples/images/fish_1.jpeg b/examples/images/fish_1.jpeg new file mode 100644 index 00000000..6413569e Binary files /dev/null and b/examples/images/fish_1.jpeg differ diff --git a/examples/images/fish_2.jpeg b/examples/images/fish_2.jpeg new file mode 100644 index 00000000..581f1b17 Binary files /dev/null and b/examples/images/fish_2.jpeg differ diff --git a/examples/images/shark_1.jpeg b/examples/images/shark_1.jpeg new file mode 100644 index 00000000..e7ebb18f Binary files /dev/null and b/examples/images/shark_1.jpeg differ diff --git a/examples/images/shark_2.jpeg b/examples/images/shark_2.jpeg new file mode 100644 index 00000000..037c53dd Binary files /dev/null and b/examples/images/shark_2.jpeg differ diff --git a/examples/images/shark_3.jpeg b/examples/images/shark_3.jpeg new file mode 100644 index 00000000..640d01f2 Binary files /dev/null and b/examples/images/shark_3.jpeg differ diff --git a/examples/images/unused_cards/two_clubs.jpeg b/examples/images/unused_cards/two_clubs.jpeg new file mode 100644 index 00000000..ca819af2 Binary files /dev/null and b/examples/images/unused_cards/two_clubs.jpeg differ diff --git a/examples/multiple_classifier_integration_ex.py b/examples/multiple_classifier_integration_ex.py new file mode 100644 index 00000000..cdb55d6c --- /dev/null +++ b/examples/multiple_classifier_integration_ex.py @@ -0,0 +1,136 @@ +import pyreason as pr +import torch +import torch.nn as nn +import networkx as nx +import numpy as np +import random + + +# seed_value = 41 # legitimate, high risk +# seed_value = 42 # fraud, low risk +seed_value = 44 # fraud, high risk +random.seed(seed_value) +np.random.seed(seed_value) +torch.manual_seed(seed_value) + + +# --- Part 1: Fraud Detector Model Integration --- +# Create a dummy PyTorch model for transaction fraud detection. +fraud_model = nn.Linear(5, 2) +fraud_class_names = ["fraud", "legitimate"] +transaction_features = torch.rand(1, 5) + +# Define integration options: only probabilities > 0.5 will trigger bounds adjustment. +fraud_interface_options = pr.ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=False, + snap_value=1.0 +) + +# Wrap the fraud detection model. +fraud_detector = pr.LogicIntegratedClassifier( + fraud_model, + fraud_class_names, + model_name="fraud_detector", + interface_options=fraud_interface_options +) + +# Run the fraud detector. +logits_fraud, probabilities_fraud, fraud_facts = fraud_detector(transaction_features) # Talk about time +print("=== Fraud Detector Output ===") +print("Logits:", logits_fraud) +print("Probabilities:", probabilities_fraud) +print("\nGenerated Fraud Detector Facts:") +for fact in fraud_facts: + print(fact) + +# Context and reasoning +for fact in fraud_facts: + pr.add_fact(fact) + +# Add additional contextual facts: +# 1. The transaction is from a suspicious location. +pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact")) +# 2. Link the transaction to AccountA. +pr.add_fact(pr.Fact("transaction(AccountA)", "transaction_link")) +# 3. Register AccountA as an account. +pr.add_fact(pr.Fact("account(AccountA)", "account_fact")) + +# Define reasoning rules: +# Rule A: If the fraud detector flags fraud and the transaction is suspicious, mark AccountA for investigation. +pr.add_rule(pr.Rule("requires_investigation(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud)", "investigation_rule")) + +# --- Set up Graph and Load --- +# Build a simple graph of accounts. +G = nx.DiGraph() +G.add_node("AccountA") +G.add_node("AccountB") +G.add_node("AccountC") +# Add edges with an attribute "relationship" set to "associated". +G.add_edge("AccountA", "AccountB", associated=1) +G.add_edge("AccountB", "AccountC", associated=1) +# Load the graph into PyReason. The edge attribute "relationship" is interpreted as the predicate 'associated'. +pr.load_graph(G) + +# Define propagation rules to spread investigation and critical action flags via the "associated" relationship. +pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "investigation_propagation_rule")) + +# --- Part 5: Run the Reasoning Engine --- +# Run the reasoning engine. +pr.settings.allow_ground_rules = True +pr.settings.atom_trace = True +interpretation = pr.reason() + +# Display reasoning results for 'requires_investigation'. +print("\n=== Reasoning Results for 'requires_investigation' ===") +trace = pr.get_rule_trace(interpretation) +print(f"RULE TRACE: \n\n{trace[0]}\n") + + +# --- Part 2: Risk Evaluator Model Integration --- +# Create another dummy PyTorch model for evaluating account risk. +risk_model = nn.Linear(5, 2) +risk_class_names = ["high_risk", "low_risk"] +risk_features = torch.rand(1, 5) + +# Define integration options for the risk evaluator. +risk_interface_options = pr.ModelInterfaceOptions( + threshold=0.5, + set_lower_bound=True, + set_upper_bound=True, + snap_value=1.0 +) + +# Wrap the risk evaluation model. +risk_evaluator = pr.LogicIntegratedClassifier( + risk_model, + risk_class_names, # document len + model_name="risk_evaluator", # binded constant + interface_options=risk_interface_options +) + +# Run the risk evaluator. +logits_risk, probabilities_risk, risk_facts = risk_evaluator(risk_features) +print("\n=== Risk Evaluator Output ===") +print("Logits:", logits_risk) +print("Probabilities:", probabilities_risk) +print("\nGenerated Risk Evaluator Facts:") +for fact in risk_facts: + print(fact) + +# --- Context and Reasoning again --- +for fact in risk_facts: + pr.add_fact(fact) + +# Rule B: If the fraud detector flags fraud and the risk evaluator flags high risk, mark AccountA for critical action. +pr.add_rule(pr.Rule("critical_action(acc) <- transaction(acc), suspicious_location(acc), fraud_detector(fraud), risk_evaluator(high_risk)", "critical_action_rule")) +pr.add_rule(pr.Rule("critical_action(y) <- critical_action(x), associated(x,y)", "critical_propagation_rule")) + +interpretation = pr.reason(again=True) + +# Display reasoning results for 'critical_action'. +print("\n=== Reasoning Results for 'critical_action' (Reasoning again) ===") +trace = pr.get_rule_trace(interpretation) +print(f"RULE TRACE: \n\n{trace[0]}\n") + diff --git a/pyreason/pyreason.py b/pyreason/pyreason.py index 31dcb139..c04d04f2 100755 --- a/pyreason/pyreason.py +++ b/pyreason/pyreason.py @@ -8,6 +8,7 @@ import warnings from typing import List, Type, Callable, Tuple, Optional +from pyreason.scripts.interpretation.interpretation_parallel import Interpretation from pyreason.scripts.utils.output import Output from pyreason.scripts.utils.filter import Filter from pyreason.scripts.program.program import Program @@ -33,6 +34,7 @@ print('torch is not installed, model integration is disabled') else: from pyreason.scripts.learning.classification.classifier import LogicIntegratedClassifier + from pyreason.scripts.learning.classification.temporal_classifier import TemporalLogicIntegratedClassifier from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions @@ -637,6 +639,38 @@ def add_annotation_function(function: Callable) -> None: __annotation_functions.append(function) +def get_logic_program() -> Optional[Program]: + """Get the logic program object + + :return: Logic program object + """ + global __program + return __program + + +def get_interpretation() -> Optional[Interpretation]: + """Get the current interpretation + + :return: Current interpretation + """ + global __program + if __program is None: + raise Exception('No interpretation found. Please run `pr.reason()` first') + return __program.interp + + +def get_time() -> int: + """Get the current time + + :return: Current time + """ + try: + i = get_interpretation() + except Exception: + return 0 + return i.time + 1 + + def reason(timesteps: int = -1, convergence_threshold: int = -1, convergence_bound_threshold: float = -1, queries: List[Query] = None, again: bool = False, restart: bool = True): """Function to start the main reasoning process. Graph and rules must already be loaded. diff --git a/pyreason/scripts/learning/classification/classifier.py b/pyreason/scripts/learning/classification/classifier.py index 06bc3ec9..6b5bdce2 100644 --- a/pyreason/scripts/learning/classification/classifier.py +++ b/pyreason/scripts/learning/classification/classifier.py @@ -4,88 +4,92 @@ import torch.nn.functional as F from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions -class LogicIntegratedClassifier(torch.nn.Module): +class LogicIntegratedClassifier(LogicIntegrationBase): """ Class to integrate a PyTorch model with PyReason. The output of the model is returned to the user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them. + Wraps any torch.nn.Module whose forward(x) returns [N, C] logits (multi-class). + Implements _infer, _postprocess, and _pred_to_facts to replace the original forward(). """ - def __init__(self, model, class_names: List[str], identifier: str = 'classifier', interface_options: ModelInterfaceOptions = None): - """ - :param model: PyTorch model to be integrated. - :param class_names: List of class names for the model output. - :param identifier: Identifier for the model, used as the constant in the facts. - :param interface_options: Options for the model interface, including threshold and snapping behavior. - """ - super(LogicIntegratedClassifier, self).__init__() - self.model = model - self.class_names = class_names - self.identifier = identifier - self.interface_options = interface_options - def get_class_facts(self, t1: int, t2: int) -> List[Fact]: + def __init__( + self, + model: torch.nn.Module, + class_names: List[str], + identifier: str = 'classifier', + interface_options: ModelInterfaceOptions = None + ): + super().__init__(model, class_names, interface_options, identifier) + + def _infer(self, x: torch.Tensor) -> torch.Tensor: + # Simply run the underlying model to get raw logits [N, C] + return self.model(x) + + def _postprocess(self, raw_output: torch.Tensor) -> torch.Tensor: """ - Return PyReason facts to create nodes for each class. Each class node will have bounds `[1,1]` with the - predicate corresponding to the model name. - :param t1: Start time for the facts - :param t2: End time for the facts - :return: List of PyReason facts + raw_output: a [N, C] logits tensor. + Apply softmax over dim=1 to get probabilities [N, C]. """ - facts = [] - for c in self.class_names: - fact = Fact(f'{c}({self.identifier})', name=f'{self.identifier}-{c}-fact', start_time=t1, end_time=t2) - facts.append(fact) - return facts + logits = raw_output + if logits.dim() != 2 or logits.size(1) != len(self.class_names): + raise ValueError( + f"Expected logits of shape [N, C] with C={len(self.class_names)}, " + f"got {tuple(logits.shape)}" + ) + return F.softmax(logits, dim=1) - def forward(self, x, t1: int = 0, t2: int = 0) -> Tuple[torch.Tensor, torch.Tensor, List[Fact]]: + def _pred_to_facts( + self, + raw_output: torch.Tensor, + probabilities: torch.Tensor, + t1: int, + t2: int + ) -> List[Fact]: """ - Forward pass of the model - :param x: Input tensor - :param t1: Start time for the facts - :param t2: End time for the facts - :return: Output tensor + Turn the [N, C] probability tensor into a flat List[Fact], + using threshold, snap_value, set_lower_bound, set_upper_bound. + Produces N * C facts. """ - output = self.model(x) - - # Convert logits to probabilities assuming a multi-class classification. - probabilities = F.softmax(output, dim=1).squeeze() opts = self.interface_options + prob = probabilities # [N, C] - # Prepare threshold tensor. - threshold = torch.tensor(opts.threshold, dtype=probabilities.dtype, device=probabilities.device) - condition = probabilities > threshold + # Build a threshold tensor + threshold = torch.tensor(opts.threshold, dtype=prob.dtype, device=prob.device) + condition = prob > threshold # [N, C] boolean + # Determine lower/upper for “true” entries if opts.snap_value is not None: - snap_value = torch.tensor(opts.snap_value, dtype=probabilities.dtype, device=probabilities.device) - # For values that pass the threshold: - lower_val = snap_value if opts.set_lower_bound else torch.tensor(0.0, dtype=probabilities.dtype, - device=probabilities.device) - upper_val = snap_value if opts.set_upper_bound else torch.tensor(1.0, dtype=probabilities.dtype, - device=probabilities.device) + snap_val = torch.tensor(opts.snap_value, dtype=prob.dtype, device=prob.device) + lower_if_true = ( + snap_val if opts.set_lower_bound else torch.tensor(0.0, dtype=prob.dtype, device=prob.device) + ) + upper_if_true = ( + snap_val if opts.set_upper_bound else torch.tensor(1.0, dtype=prob.dtype, device=prob.device) + ) else: - # If no snap_value is provided, keep original probabilities for those passing threshold. - lower_val = probabilities if opts.set_lower_bound else torch.zeros_like(probabilities) - upper_val = probabilities if opts.set_upper_bound else torch.ones_like(probabilities) + lower_if_true = prob if opts.set_lower_bound else torch.zeros_like(prob) + upper_if_true = prob if opts.set_upper_bound else torch.ones_like(prob) - # For probabilities that pass the threshold, apply the above; else, bounds are fixed to [0,1]. - lower_bounds = torch.where(condition, lower_val, torch.zeros_like(probabilities)) - upper_bounds = torch.where(condition, upper_val, torch.ones_like(probabilities)) + # Build full [N, C] lower_bounds and upper_bounds + zeros = torch.zeros_like(prob) + ones = torch.ones_like(prob) + lower_bounds = torch.where(condition, lower_if_true, zeros) # [N, C] + upper_bounds = torch.where(condition, upper_if_true, ones) # [N, C] - # Convert bounds to Python floats for fact creation. - bounds_list = [] - for i in range(len(self.class_names)): - lower = lower_bounds[i].item() - upper = upper_bounds[i].item() - bounds_list.append([lower, upper]) + N, C = prob.shape + facts: List[Fact] = [] - # Define time bounds for the facts. - facts = [] - for class_name, bounds in zip(self.class_names, bounds_list): - lower, upper = bounds - fact_str = f'{class_name}({self.identifier}) : [{lower:.3f}, {upper:.3f}]' - fact = Fact(fact_str, name=f'{self.identifier}-{class_name}-fact', start_time=t1, end_time=t2) - facts.append(fact) - return output, probabilities, facts + for i in range(N): + for j, class_name in enumerate(self.class_names): + l = lower_bounds[i, j].item() + u = upper_bounds[i, j].item() + fact_str = f"{class_name}({self.identifier}) : [{l:.3f}, {u:.3f}]" + fact_name = f"{self.identifier}-{class_name}-fact" + f = Fact(fact_str, name=fact_name, start_time=t1, end_time=t2) + facts.append(f) + return facts diff --git a/pyreason/scripts/learning/classification/logic_integration_base.py b/pyreason/scripts/learning/classification/logic_integration_base.py new file mode 100644 index 00000000..3ad8eb83 --- /dev/null +++ b/pyreason/scripts/learning/classification/logic_integration_base.py @@ -0,0 +1,125 @@ +import torch +import torch.nn.functional as F +from abc import ABC, abstractmethod +from typing import List, Tuple, Any + +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions + + +class LogicIntegrationBase(torch.nn.Module, ABC): + """ + Abstract base class for **any** model (classifier, detector, etc.) whose + outputs you want to convert into PyReason Facts with lower/upper bounds. + + Subclasses must implement: + 1. _infer(x) → raw_output + 2. _pred_to_facts(raw_output, t1, t2) → List[Fact] + + The base class handles: + - Calling `self.model(x)` + - Applying threshold, snap_value, and bound‐construction (for “probabilistic” heads), + if desired. + - Packaging everything into a final (raw_output, probs_or_filtered, facts) tuple. + """ + + def __init__( + self, + model: torch.nn.Module, + class_names: List[str], + interface_options: ModelInterfaceOptions, + identifier: str = "model" + ): + """ + :param model: Any PyTorch module. Subclasses will call it in _infer(). + :param class_names: List of “predicate” names. For a detector, this is the full label list. + :param interface_options: Contains threshold, snap_value, set_lower_bound, set_upper_bound, etc. + :param identifier: Constant to inject into each Fact (e.g. “image1”, “classifier”, “detector”). + """ + super().__init__() + self.model = model + self.class_names = class_names + self.interface_options = interface_options + self.identifier = identifier + + # (Optional) sanity‐check on class_names vs. model (each subclass can override) + self._validate_init() + + def _validate_init(self): + """ + Hook for subclasses to check, e.g. that `len(class_names)` matches + whatever the underlying model expects. + """ + pass + + def forward( + self, + x: Any, + t1: int = 0, + t2: int = 0 + ) -> Tuple[Any, Any, List[Fact]]: + """ + 1) Call `_infer(x)` to get the “raw_output.” + 2) Call `_postprocess(raw_output)` to get either “probabilities” or “filtered detections,” + depending on model‐type. + 3) Call `_pred_to_facts(raw_output, postproc, t1, t2)` to build a List[Fact]. + + Returns a 3‐tuple: + (raw_output, postproc, facts_list) + + - raw_output: whatever `model(x)` naturally returned + - postproc: a tensor of probabilities or a list of filtered boxes, etc. + - facts_list: a flat List[Fact] + """ + # 1) raw predictions + raw_output = self._infer(x) + + # 2) “postprocess” step (e.g. softmax/sigmoid + threshold for classifiers, + # or filtering by confidence for detectors) + postproc = self._postprocess(raw_output) + + # 3) Turn them into Facts + facts: List[Fact] = self._pred_to_facts(raw_output, postproc, t1, t2) + + return raw_output, postproc, facts + + @abstractmethod + def _infer(self, x: Any) -> Any: + """ + Run the underlying PyTorch model (self.model) on input x, returning + the “raw” output. For a classifier, this is a logit‐tensor. For a YOLO detector, + this might be a Results object whose `.xyxy[i]` is a [num_det×6] tensor, etc. + """ + ... + + @abstractmethod + def _postprocess(self, raw_output: Any) -> Any: + """ + Convert raw model outputs into a more convenient “postprocessed” form + that we’ll pass both to the user and into `_pred_to_facts`. + + - For a binary/multiclass classifier, apply sigmoid/softmax + threshold mask. + - For a multilabel classifier, apply sigmoid + per‐class threshold mask. + - For a detector, extract a list of (class_idx, confidence) for all detections + above threshold. + """ + ... + + @abstractmethod + def _pred_to_facts( + self, + raw_output: Any, + postproc: Any, + t1: int, + t2: int + ) -> List[Fact]: + """ + Given raw_output and postproc (see above), build a List of PyReason Fact(...) objects, + each of the form: + f"{class_name}({self.identifier}) : [lower, upper]" + + - raw_output: whatever the model returned + - postproc: tensor-of-probs or list‐of‐(class_idx,confidence) + - t1, t2: start/end timestamps + """ + ... \ No newline at end of file diff --git a/pyreason/scripts/learning/classification/temporal_classifier.py b/pyreason/scripts/learning/classification/temporal_classifier.py new file mode 100644 index 00000000..5fe01514 --- /dev/null +++ b/pyreason/scripts/learning/classification/temporal_classifier.py @@ -0,0 +1,243 @@ +import asyncio +import threading +import time +from datetime import timedelta +from datetime import datetime +from typing import List, Tuple, Optional, Union, Callable, Any + +import torch.nn +import torch.nn.functional as F + +import pyreason as pr +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions + + +class TemporalLogicIntegratedClassifier(LogicIntegrationBase): + """ + Wraps any torch.nn.Module whose forward(x) returns [N, C] logits (multi‐class), + but additionally polls in the background (either every N timesteps or every N seconds) + and injects new Facts into a PyReason logic program. + """ + def __init__( + self, + model, + class_names: List[str], + identifier: str = 'classifier', + interface_options: ModelInterfaceOptions = None, + logic_program=None, + poll_interval: Optional[Union[int, timedelta]] = None, + poll_condition: Optional[str] = None, + input_fn: Optional[Callable[[], Any]] = None, + ): + """ + :param model: PyTorch model to be integrated. + :param class_names: List of class names for the model output. + :param identifier: Identifier for the model, used as the constant in the facts. + :param interface_options: Options for the model interface, including threshold and snapping behavior. + :param logic_program: PyReason logic program + :param poll_interval: How often to poll the model, either as: + - an integer number of PyReason timesteps or + - a `datetime.timedelta` object representing wall-clock time. + If `None`, polling is disabled. + :param poll_condition: The name of the predicate attached to the model that must be true to trigger a poll. + If `None`, the model will be polled every `poll_interval` time steps/seconds. + :param input_fn: Function to call to get the input to the model. This function should return a tensor. + """ + super().__init__(model, class_names, interface_options, identifier) + self.model = model + self.class_names = class_names + self.identifier = identifier + self.interface_options = interface_options + self.logic_program = logic_program + self.poll_interval = poll_interval + self.poll_condition = poll_condition + self.input_fn = input_fn + + # normalize poll_interval + if isinstance(poll_interval, int): + self.poll_interval: Union[int, timedelta, None] = poll_interval + else: + self.poll_interval = poll_interval + + # start the async polling task if configured + if self.poll_interval is not None and self.input_fn is not None: + # this schedules the background task + # self._poller_task = asyncio.create_task(self._poll_loop()) + # kick off the background thread + t = threading.Thread(target=self._poll_loop, daemon=True) + t.start() + + def _get_current_timestep(self): + """ + Get the current timestep from the PyReason logic program. + :return: Current timestep + """ + if self.logic_program is not None and self.logic_program.interp is not None: + interp = self.logic_program.interp + t = interp.time + return t + elif pr.get_logic_program() is not None and pr.get_logic_program().interp is not None: + self.logic_program = pr.get_logic_program() + interp = self.logic_program.interp + t = interp.time + return t + else: + # raise ValueError("No PyReason logic program provided.") + return -1 + + def _poll_loop(self) -> None: + """ + Background async loop that polls every self.poll_interval. + """ + # if self.logic_program is None: + # raise ValueError("No logic program to add facts into.") + + # check if we have a logic program yet or not + while True: + current_time = self._get_current_timestep() + # print("here") + if current_time != -1: + print("current time", current_time) + # determine mode + if isinstance(self.poll_interval, timedelta): + interval_secs = self.poll_interval.total_seconds() + while True: + print("in loop") + time.sleep(interval_secs) + current_time = self._get_current_timestep() + t1 = current_time + 1 + t2 = t1 + + if self.poll_condition: + print(f"{self.poll_condition}({self.identifier})") + print(self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})"))) + if not self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})")): + continue + + x = self.input_fn() + _, _, facts = self.forward(x, t1, t2) + for f in facts: + print(f) + pr.add_fact(f) + + # run the reasoning + pr.reason(again=True, restart=False) + print("reasoning done") + trace = pr.get_rule_trace(self.logic_program.interp) + print(trace[0]) + + else: + step_interval = self.poll_interval + last_step = current_time + 1 + while True: + # wait until enough timesteps have passed + while self._get_current_timestep() - last_step < step_interval: + time.sleep(0.01) + current = self._get_current_timestep() + t1, t2 = current, current + last_step = current + + if self.poll_condition: + if not self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})")): + continue + + x = self.input_fn() + _, _, facts = self.forward(x, t1, t2) + for f in facts: + pr.add_fact(f) + + # run the reasoning + pr.reason(again=True, restart=False) + print("reasoning done") + trace = pr.get_rule_trace(self.logic_program.interp) + print(trace[0]) + + # # run the reasoning + # pr.reason(again=True, restart=False) + # print("reasoning done") + # trace = pr.get_rule_trace(interpretation) + # print(trace[0]) + + def get_class_facts(self, t1: int, t2: int) -> List[Fact]: + """ + Return PyReason facts to create nodes for each class. Each class node will have bounds `[1,1]` with the + predicate corresponding to the model name. + :param t1: Start time for the facts + :param t2: End time for the facts + :return: List of PyReason facts + """ + facts = [] + for c in self.class_names: + fact = Fact(f'{c}({self.identifier})', name=f'{self.identifier}-{c}-fact', start_time=t1, end_time=t2) + facts.append(fact) + return facts + + def _infer(self, x: torch.Tensor) -> torch.Tensor: + """ + Run the underlying model to get raw logits [N, C]. + """ + return self.model(x) + + def _postprocess(self, raw_output: torch.Tensor) -> torch.Tensor: + """ + raw_output should be a [N, C] logits tensor. Assert C == len(class_names), + then apply softmax over dim=1 → [N, C] probabilities. + """ + logits = raw_output + if logits.dim() != 2 or logits.size(1) != len(self.class_names): + raise ValueError( + f"Expected logits of shape [N, C] with C={len(self.class_names)}, got {tuple(logits.shape)}" + ) + return F.softmax(logits, dim=1) + + def _pred_to_facts( + self, + raw_output: torch.Tensor, + probabilities: torch.Tensor, + t1: int, + t2: int + ) -> List[Fact]: + """ + Given a [N, C] probability tensor, build a flat List[Fact], + using threshold, snap_value, set_lower_bound, set_upper_bound. + Returns N * C facts. + """ + opts = self.interface_options + prob = probabilities # [N, C] + + # Build a threshold tensor + threshold = torch.tensor(opts.threshold, dtype=prob.dtype, device=prob.device) + condition = prob > threshold # [N, C] boolean mask + + # Determine lower/upper for “true” entries + if opts.snap_value is not None: + snap_val = torch.tensor(opts.snap_value, dtype=prob.dtype, device=prob.device) + lower_if_true = (snap_val if opts.set_lower_bound + else torch.tensor(0.0, dtype=prob.dtype, device=prob.device)) + upper_if_true = (snap_val if opts.set_upper_bound + else torch.tensor(1.0, dtype=prob.dtype, device=prob.device)) + else: + lower_if_true = prob if opts.set_lower_bound else torch.zeros_like(prob) + upper_if_true = prob if opts.set_upper_bound else torch.ones_like(prob) + + zeros = torch.zeros_like(prob) + ones = torch.ones_like(prob) + lower_bounds = torch.where(condition, lower_if_true, zeros) # [N, C] + upper_bounds = torch.where(condition, upper_if_true, ones) # [N, C] + + N, C = prob.shape + all_facts: List[Fact] = [] + + for i in range(N): + for j, class_name in enumerate(self.class_names): + lower_val = lower_bounds[i, j].item() + upper_val = upper_bounds[i, j].item() + fact_str = f"{class_name}({self.identifier}) : [{lower_val:.3f}, {upper_val:.3f}]" + fact_name = f"{self.identifier}-{class_name}-fact" + f = Fact(fact_str, name=fact_name, start_time=t1, end_time=t2) + all_facts.append(f) + + return all_facts + diff --git a/pyreason/scripts/learning/classification/yolo_classifier.py b/pyreason/scripts/learning/classification/yolo_classifier.py new file mode 100644 index 00000000..1e784d67 --- /dev/null +++ b/pyreason/scripts/learning/classification/yolo_classifier.py @@ -0,0 +1,213 @@ +from datetime import timedelta +from pathlib import Path +import random +import threading +import time + +import cv2 +import torch +import pyreason as pr +from pyreason.scripts.facts.fact import Fact +from pyreason.scripts.learning.classification.logic_integration_base import LogicIntegrationBase +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions + +from typing import List, Tuple, Optional, Union, Callable, Any + +class YoloLogicIntegratedTemporalClassifier(LogicIntegrationBase): + """ + Class to integrate a YOLO model with PyReason. The output of the model is returned to the + user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them. + Wraps a YOLO model whose forward(x) returns bounding boxes with class probabilities. + Implements _infer, _postprocess, and _pred_to_facts to replace the original forward(). + """ + + def __init__( + self, + model, + class_names: List[str], + identifier: str = 'yolo_classifier', + interface_options: ModelInterfaceOptions = None, + poll_interval: Optional[Union[int, timedelta]] = None, + poll_condition: Optional[str] = None, + input_fn: Optional[Callable[[], Any]] = None + ): + """ + :param model: PyTorch model to be integrated. + :param class_names: List of class names for the model output. + :param identifier: Identifier for the model, used as the constant in the facts. + :param interface_options: Options for the model interface, including threshold and snapping behavior. + :param logic_program: PyReason logic program + :param poll_interval: How often to poll the model, either as: + - an integer number of PyReason timesteps or + - a `datetime.timedelta` object representing wall-clock time. + If `None`, polling is disabled. + :param poll_condition: The name of the predicate attached to the model that must be true to trigger a poll. + If `None`, the model will be polled every `poll_interval` time steps/seconds. + :param input_fn: Function to call to get the input to the model. This function should return a tensor. + """ + super().__init__(model, class_names, interface_options, identifier) + self.model = model + self.class_names = class_names + self.identifier = identifier + self.interface_options = interface_options + self.poll_interval = poll_interval + self.poll_condition = poll_condition + self.input_fn = input_fn + self.logic_program = None # Get the current logic program + + # normalize poll_interval + if isinstance(poll_interval, int): + self.poll_interval: Union[int, timedelta, None] = poll_interval + else: + self.poll_interval = poll_interval + + # start the async polling task if configured + if self.poll_interval is not None and self.input_fn is not None: + # this schedules the background task + # self._poller_task = asyncio.create_task(self._poll_loop()) + # kick off the background thread + t = threading.Thread(target=self._poll_loop, daemon=True) + t.start() + + def _infer(self, x: Any) -> Any: + print("Running YOLO model inference...") + # resized_image = cv2.resize(image, (640, 640)) # Direct resize + # normalized_image = resized_image / 255.0 # Normalize + result_predict = self.model.predict(source = x, imgsz=(640), conf=0.1) #the default image size + #print("Predicted output:", result_predict) + return result_predict + + def _postprocess(self, raw_output: Any) -> Any: + """ + Process the raw output from the YOLO model to extract bounding boxes and class probabilities. + """ + result = raw_output[0] # Get the first result from the prediction + box = result.boxes[0] # Get the first bounding box from the result + label_id = int(box.cls) + confidence = float(box.conf) + label_name = result.names[label_id] # Get the label name from the names dictionary + print(f"Predicted label: {label_name}, Confidence: {confidence:.2f}") + return [label_name, confidence] + + def _pred_to_facts( + self, + raw_output: Any, + result: List, + confidence: float, + + t1: int = None, + t2: int = None + ) -> List[Fact]: + """ + Given a [N, C] probability tensor, build a flat List[Fact], + using threshold, snap_value, set_lower_bound, set_upper_bound. + Returns N * C facts. + """ + opts = self.interface_options + label = result[0] + confidence = result[1] + # Build a threshold tensor + threshold = torch.tensor(opts.threshold) + condition = confidence > threshold # [N, C] boolean mask + + # Determine lower/upper for “true” entries + if opts.snap_value is not None: + snap_val = opts.snap_value + lower_if_true = (snap_val if opts.set_lower_bound + else 0) + upper_if_true = (snap_val if opts.set_upper_bound + else 1.0) + else: + lower_if_true = confidence if opts.set_lower_bound else 0 + upper_if_true = confidence if opts.set_upper_bound else 1.0 + + all_facts: List[Fact] = [] + + fact_str = f"_{label}({self.identifier}) : [{lower_if_true:.3f}, {upper_if_true:.3f}]" + fact_name = f"{self.identifier}-{label}-fact" + f = Fact(fact_str, name=fact_name, start_time=0, end_time=0) + all_facts.append(f) + + return all_facts + + def _get_current_timestep(self): + """ + Get the current timestep from the PyReason logic program. + :return: Current timestep + """ + if self.logic_program is not None and self.logic_program.interp is not None: + interp = self.logic_program.interp + t = interp.time + return t + elif pr.get_logic_program() is not None and pr.get_logic_program().interp is not None: + self.logic_program = pr.get_logic_program() + interp = self.logic_program.interp + t = interp.time + return t + else: + # raise ValueError("No PyReason logic program provided.") + return -1 + + def _poll_loop(self) -> None: + """ + Background async loop that polls every self.poll_interval. + """ + # if self.logic_program is None: + # raise ValueError("No logic program to add facts into.") + + # check if we have a logic program yet or not + while True: + current_time = self._get_current_timestep() + if current_time != -1: + # determine mode + if isinstance(self.poll_interval, timedelta): + interval_secs = self.poll_interval.total_seconds() + while True: + time.sleep(interval_secs) + current_time = self._get_current_timestep() + t1 = current_time + 1 + t2 = t1 + + if self.poll_condition: + print(f"{self.poll_condition}({self.identifier})") + print(self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})"))) + if not self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})")): + print(f"Condition {self.poll_condition} not met, skipping poll.") + continue + print("Condition met, polling model...") + x = self.input_fn() + _, _, facts = self.forward(x, t1, t2) + for f in facts: + print(f) + pr.add_fact(f) + + # run the reasoning + pr.reason(again=True, restart=True) + print("reasoning done") + trace = pr.get_rule_trace(self.logic_program.interp) + print(trace[0]) + + else: + step_interval = self.poll_interval + last_step = current_time + 1 + while True: + # wait until enough timesteps have passed + while self._get_current_timestep() - last_step < step_interval: + time.sleep(0.01) + current = self._get_current_timestep() + t1, t2 = current, current + last_step = current + + if self.poll_condition: + if not self.logic_program.interp.query(pr.Query(f"{self.poll_condition}({self.identifier})")): + continue + + x = self.input_fn() + _, _, facts = self.forward(x, t1, t2) + for f in facts: + pr.add_fact(f) + + # run the reasoning + pr.reason(again=True, restart=False) + trace = pr.get_rule_trace(self.logic_program.interp) + print(trace[0]) diff --git a/pyreason/train56/F1_curve.png b/pyreason/train56/F1_curve.png new file mode 100644 index 00000000..52bb898c Binary files /dev/null and b/pyreason/train56/F1_curve.png differ diff --git a/pyreason/train56/PR_curve.png b/pyreason/train56/PR_curve.png new file mode 100644 index 00000000..88947d0b Binary files /dev/null and b/pyreason/train56/PR_curve.png differ diff --git a/pyreason/train56/P_curve.png b/pyreason/train56/P_curve.png new file mode 100644 index 00000000..227b6af3 Binary files /dev/null and b/pyreason/train56/P_curve.png differ diff --git a/pyreason/train56/R_curve.png b/pyreason/train56/R_curve.png new file mode 100644 index 00000000..7f85a0a2 Binary files /dev/null and b/pyreason/train56/R_curve.png differ diff --git a/pyreason/train56/args.yaml b/pyreason/train56/args.yaml new file mode 100644 index 00000000..2cc2cddb --- /dev/null +++ b/pyreason/train56/args.yaml @@ -0,0 +1,105 @@ +task: detect +mode: train +model: yolov5nu.pt +data: /Users/coltonpayne/datasets/playing_card_dataset/kaggle_data.yaml +epochs: 10 +time: null +patience: 100 +batch: -1 +imgsz: 640 +save: true +save_period: -1 +cache: false +device: cpu +workers: 0 +project: null +name: train56 +exist_ok: false +pretrained: true +optimizer: auto +verbose: true +seed: 0 +deterministic: true +single_cls: false +rect: false +cos_lr: false +close_mosaic: 10 +resume: false +amp: true +fraction: 1.0 +profile: false +freeze: null +multi_scale: false +overlap_mask: true +mask_ratio: 4 +dropout: 0.0 +val: true +split: val +save_json: false +conf: null +iou: 0.7 +max_det: 300 +half: false +dnn: false +plots: true +source: null +vid_stride: 1 +stream_buffer: false +visualize: false +augment: false +agnostic_nms: false +classes: null +retina_masks: false +embed: null +show: false +save_frames: false +save_txt: false +save_conf: false +save_crop: false +show_labels: true +show_conf: true +show_boxes: true +line_width: null +format: torchscript +keras: false +optimize: false +int8: false +dynamic: false +simplify: true +opset: null +workspace: null +nms: false +lr0: 0.01 +lrf: 0.01 +momentum: 0.937 +weight_decay: 0.0005 +warmup_epochs: 3.0 +warmup_momentum: 0.8 +warmup_bias_lr: 0.0 +box: 7.5 +cls: 0.5 +dfl: 1.5 +pose: 12.0 +kobj: 1.0 +nbs: 64 +hsv_h: 0.015 +hsv_s: 0.7 +hsv_v: 0.4 +degrees: 0.0 +translate: 0.1 +scale: 0.5 +shear: 0.0 +perspective: 0.0 +flipud: 0.0 +fliplr: 0.5 +bgr: 0.0 +mosaic: 1.0 +mixup: 0.0 +cutmix: 0.0 +copy_paste: 0.0 +copy_paste_mode: flip +auto_augment: randaugment +erasing: 0.4 +cfg: null +tracker: botsort.yaml +save_dir: /Users/coltonpayne/pyreason/runs/detect/train56 diff --git a/pyreason/train56/confusion_matrix.png b/pyreason/train56/confusion_matrix.png new file mode 100644 index 00000000..e57b21f7 Binary files /dev/null and b/pyreason/train56/confusion_matrix.png differ diff --git a/pyreason/train56/confusion_matrix_normalized.png b/pyreason/train56/confusion_matrix_normalized.png new file mode 100644 index 00000000..4903f07f Binary files /dev/null and b/pyreason/train56/confusion_matrix_normalized.png differ diff --git a/pyreason/train56/labels.jpg b/pyreason/train56/labels.jpg new file mode 100644 index 00000000..13f4c716 Binary files /dev/null and b/pyreason/train56/labels.jpg differ diff --git a/pyreason/train56/labels_correlogram.jpg b/pyreason/train56/labels_correlogram.jpg new file mode 100644 index 00000000..e2beccc0 Binary files /dev/null and b/pyreason/train56/labels_correlogram.jpg differ diff --git a/pyreason/train56/results.csv b/pyreason/train56/results.csv new file mode 100644 index 00000000..c02e7864 --- /dev/null +++ b/pyreason/train56/results.csv @@ -0,0 +1,11 @@ +epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2 +1,5063.78,0.92856,3.11485,0.91922,0.1199,0.33001,0.1183,0.09564,0.75675,2.81461,0.87198,5.95985e-05,5.95985e-05,5.95985e-05 +2,10101.1,0.83387,2.72267,0.89052,0.22664,0.4928,0.25171,0.20827,0.69211,2.29976,0.85697,0.000107458,0.000107458,0.000107458 +3,15123.1,0.76501,2.29966,0.87401,0.33506,0.57945,0.3979,0.33987,0.61122,1.91816,0.83818,0.000143503,0.000143503,0.000143503 +4,20158.7,0.71159,1.96178,0.86,0.46015,0.68805,0.55268,0.47708,0.58224,1.58141,0.83269,0.000125837,0.000125837,0.000125837 +5,25164.5,0.67845,1.71499,0.85183,0.55518,0.72521,0.65981,0.57435,0.56744,1.35155,0.829,0.000108116,0.000108116,0.000108116 +6,30166.6,0.65638,1.53425,0.84756,0.62735,0.77086,0.72308,0.63153,0.54953,1.21587,0.82626,9.0395e-05,9.0395e-05,9.0395e-05 +7,35166.1,0.63751,1.40779,0.84108,0.69098,0.78398,0.7859,0.69156,0.53247,1.09352,0.82043,7.2674e-05,7.2674e-05,7.2674e-05 +8,42297.8,0.62295,1.30591,0.83881,0.71594,0.81527,0.81449,0.71846,0.52529,1.01363,0.81941,5.4953e-05,5.4953e-05,5.4953e-05 +9,47297.2,0.61573,1.24373,0.83719,0.7187,0.82442,0.82311,0.72721,0.51808,0.9842,0.8177,3.7232e-05,3.7232e-05,3.7232e-05 +10,52325.3,0.60752,1.2017,0.836,0.73426,0.82988,0.83733,0.74203,0.5146,0.94778,0.81682,1.9511e-05,1.9511e-05,1.9511e-05 diff --git a/pyreason/train56/results.png b/pyreason/train56/results.png new file mode 100644 index 00000000..d42ebc85 Binary files /dev/null and b/pyreason/train56/results.png differ diff --git a/pyreason/train56/train_batch0.jpg b/pyreason/train56/train_batch0.jpg new file mode 100644 index 00000000..aab8f47c Binary files /dev/null and b/pyreason/train56/train_batch0.jpg differ diff --git a/pyreason/train56/train_batch1.jpg b/pyreason/train56/train_batch1.jpg new file mode 100644 index 00000000..6a687111 Binary files /dev/null and b/pyreason/train56/train_batch1.jpg differ diff --git a/pyreason/train56/train_batch2.jpg b/pyreason/train56/train_batch2.jpg new file mode 100644 index 00000000..a3038ecb Binary files /dev/null and b/pyreason/train56/train_batch2.jpg differ diff --git a/pyreason/train56/val_batch0_labels.jpg b/pyreason/train56/val_batch0_labels.jpg new file mode 100644 index 00000000..5cdf9824 Binary files /dev/null and b/pyreason/train56/val_batch0_labels.jpg differ diff --git a/pyreason/train56/val_batch0_pred.jpg b/pyreason/train56/val_batch0_pred.jpg new file mode 100644 index 00000000..a47a6abb Binary files /dev/null and b/pyreason/train56/val_batch0_pred.jpg differ diff --git a/pyreason/train56/val_batch1_labels.jpg b/pyreason/train56/val_batch1_labels.jpg new file mode 100644 index 00000000..2002c490 Binary files /dev/null and b/pyreason/train56/val_batch1_labels.jpg differ diff --git a/pyreason/train56/val_batch1_pred.jpg b/pyreason/train56/val_batch1_pred.jpg new file mode 100644 index 00000000..6dc195af Binary files /dev/null and b/pyreason/train56/val_batch1_pred.jpg differ diff --git a/pyreason/train56/val_batch2_labels.jpg b/pyreason/train56/val_batch2_labels.jpg new file mode 100644 index 00000000..2c5d30c6 Binary files /dev/null and b/pyreason/train56/val_batch2_labels.jpg differ diff --git a/pyreason/train56/val_batch2_pred.jpg b/pyreason/train56/val_batch2_pred.jpg new file mode 100644 index 00000000..2ab61187 Binary files /dev/null and b/pyreason/train56/val_batch2_pred.jpg differ diff --git a/pyreason/train56/weights/best.pt b/pyreason/train56/weights/best.pt new file mode 100644 index 00000000..108f010b Binary files /dev/null and b/pyreason/train56/weights/best.pt differ diff --git a/pyreason/train56/weights/last.pt b/pyreason/train56/weights/last.pt new file mode 100644 index 00000000..a85a9788 Binary files /dev/null and b/pyreason/train56/weights/last.pt differ diff --git a/tests/math_history.graphml b/tests/math_history.graphml new file mode 100644 index 00000000..706aa27f --- /dev/null +++ b/tests/math_history.graphml @@ -0,0 +1,132 @@ + + + + + + + + + + Euclid + -300 + -275 + Elements, Geometry + Greek + + + Archimedes + -287 + -212 + Area, Volume, Pi + Greek + + + Diophantus + 201 + 285 + Arithmetica, Algebra + Greek + + + Aryabhata + 476 + 550 + Sine tables, Astronomy + + + Brahmagupta + 598 + 668 + Zero, Negative numbers + + + Al-Khwarizmi + 780 + 850 + Algebra, Algorithms + + + Fibonacci + 1170 + 1250 + Fibonacci sequence, Hindu-Arabic numerals + Italian + + + Isaac Newton + 1643 + 1727 + Calculus, Laws of motion + + + Gottfried Leibniz + 1646 + 1716 + Calculus, Notation + + + Leonhard Euler + 1707 + 1783 + Analysis, Number theory + + + Carl Friedrich Gauss + 1777 + 1855 + Number theory, Statistics + + + Alan Turing + 1912 + 1954 + Computer science, Turing machine + + + Pythagoreas + -570 + -495 + Geometry, Pythagorean Theorem + + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + Preceded + + + \ No newline at end of file diff --git a/tests/test.py b/tests/test.py new file mode 100644 index 00000000..76fb42f3 --- /dev/null +++ b/tests/test.py @@ -0,0 +1,83 @@ +import pyreason as pr +import torch +import torch.nn as nn +import networkx as nx + +# --- Part 1: Fraud Detector Model Integration --- + +# Create a dummy PyTorch model for transaction fraud detection. +# Each transaction is represented by 5 features and is classified into "fraud" or "legitimate". +model = nn.Linear(5, 2) +class_names = ["fraud", "legitimate"] + +# Create a dummy transaction feature vector. +transaction_features = torch.rand(1, 5) + +# Define integration options. +# Only probabilities above 0.5 are considered for adjustment. +interface_options = pr.ModelInterfaceOptions( + threshold=0.4, # Only process probabilities above 0.5 + set_lower_bound=True, # For high confidence, adjust the lower bound. + set_upper_bound=False, # Keep the upper bound unchanged. + snap_value=1.0 # Use 1.0 as the snap value. +) + +# Wrap the model using LogicIntegratedClassifier +fraud_detector = pr.LogicIntegratedClassifier( + model, + class_names, + model_name="fraud_detector", + interface_options=interface_options +) + +# Run the model to obtain logits, probabilities, and generated PyReason facts. +logits, probabilities, classifier_facts = fraud_detector(transaction_features) + +print("=== Fraud Detector Output ===") +print("Logits:", logits) +print("Probabilities:", probabilities) +print("\nGenerated Classifier Facts:") +for fact in classifier_facts: + print(fact) + +# Add the classifier-generated facts. +for fact in classifier_facts: + pr.add_fact(fact) + +# --- Part 2: Create and Load a Networkx Graph representing an account knowledge base --- + +# Create a networkx graph representing a network of accounts. +G = nx.DiGraph() +# Add account nodes. +G.add_node("AccountA", account=1) +G.add_node("AccountB", account=1) +G.add_node("AccountC", account=1) +# Add edges with an attribute "relationship" set to "associated". +G.add_edge("AccountA", "AccountB", associated=1) +G.add_edge("AccountB", "AccountC", associated=1) +pr.load_graph(G) + +# --- Part 3: Set Up Context and Reasoning Environment --- + +# Add additional contextual information: +# 1. A fact indicating the transaction comes from a suspicious location. This could come from a separate fraud detection system. +pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact")) + +# Define a rule: if the fraud detector flags a transaction as fraud and the transaction info is suspicious, +# then mark the associated account (AccountA) as requiring investigation. +pr.add_rule(pr.Rule("requires_investigation(acc) <- account(acc), fraud_detector(fraud), suspicious_location(acc)", "investigation_rule")) + +# Define a propagation rule: +# If an account requires investigation and is connected (via the "associated" relationship) to another account, +# then the connected account is also flagged for investigation. +pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "propagation_rule")) + +# --- Part 4: Run the Reasoning Engine --- + +# Run the reasoning engine for 3 timesteps to allow the investigation flag to propagate through the network. +pr.settings.allow_ground_rules = True +pr.settings.atom_trace = True +interpretation = pr.reason() + +trace = pr.get_rule_trace(interpretation) +print(f"RULE TRACE: \n\n{trace[0]}\n")