|
| 1 | +Integrating PyReason with Machine Learning |
| 2 | +=========================== |
| 3 | + |
| 4 | +PyReason can be integrated with machine learning models by incorporating the predictions from the machine learning model |
| 5 | +as facts in the graph and reasoning over them with logical rules. This allows users to combine the strengths of machine learning |
| 6 | +models with the interpretability and reasoning capabilities of PyReason. |
| 7 | + |
| 8 | +Classifier Integration Example |
| 9 | +----------------------------- |
| 10 | +.. note:: |
| 11 | + Find the full, executable code `here <https://github.com/lab-v2/pyreason/blob/main/examples/classifier_integration_ex.py>`_ |
| 12 | + |
| 13 | +In this section, we will outline how to perform ML integration using a simple classification example. We assume |
| 14 | +that we have a fraud detection model that predicts whether a transaction is fraudulent or not. We will use the predictions |
| 15 | +to reason over a knowledge base of account information to identify potential fraudulent activities. For this example, we |
| 16 | +use an untrained linear model just to demonstrate, but in practice, this can be replaced by any PyTorch classification model. |
| 17 | + |
| 18 | +We start by defining our classifier model. |
| 19 | + |
| 20 | +.. code-block:: python |
| 21 | +
|
| 22 | + import torch |
| 23 | + import torch.nn as nn |
| 24 | +
|
| 25 | +
|
| 26 | + model = nn.Linear(5, 2) |
| 27 | + class_names = ["fraud", "legitimate"] |
| 28 | +
|
| 29 | +Next, we define how we want to incorporate the predictions from the model into the graph. Since the model outputs a probability |
| 30 | +over the classes, we can specify how we integrate this probability into the graph. There are a few options the user can define |
| 31 | +using a ``ModelInterfaceOptions`` object: |
| 32 | + |
| 33 | +1. ``threshold`` **(float)**: The threshold beyond which the prediction is incorporated as a fact in the graph. If the probability |
| 34 | + of the class is lower than the threshold, no information is added to the graph. Defaults to 0.5. |
| 35 | +2. ``set_lower_bound`` **(bool)**: If True, the lower bound of the probability is set as the fact in the graph. |
| 36 | + if False the lower bound will be 0. Defaults to True. |
| 37 | +3. ``set_upper_bound`` **(bool)**: If True, the upper bound of the probability is set as the fact in the graph. |
| 38 | + if False, the upper bound will be 1. Defaults to True. |
| 39 | +4. ``snap_value`` **(float)**: If set, all the probabilities that crossed the threshold are snapped to this value. Defaults to 1.0. |
| 40 | + The upper/lower bounds are set to this value according to the ``set_lower_bound`` and ``set_upper_bound`` options. |
| 41 | + |
| 42 | +In our binary classification model, we want predictions that cross the threshold of ``0.5`` to be added to the graph. |
| 43 | +For this example we will use a ``snap_value`` of ``1.0`` and set the lower bound of the probability as the fact in the graph. |
| 44 | +Therefore any prediction with a probability greater than ``0.5`` will be added to the graph as a fact with bounds of ``[1.0, 1.0]``. |
| 45 | + |
| 46 | +.. code-block:: python |
| 47 | +
|
| 48 | + interface_options = pr.ModelInterfaceOptions( |
| 49 | + threshold=0.5, # Only process probabilities above 0.5 |
| 50 | + set_lower_bound=True, # Modify the lower bound. |
| 51 | + set_upper_bound=False, # Keep the upper bound unchanged at 1.0. |
| 52 | + snap_value=1.0 # Use 1.0 as the snap value. |
| 53 | + ) |
| 54 | +
|
| 55 | +
|
| 56 | +Next, we create a ``LogicIntegratedClassifier`` object that helps us integrate the predictions from the model into the graph. |
| 57 | + |
| 58 | +.. code-block:: python |
| 59 | +
|
| 60 | + fraud_detector = pr.LogicIntegratedClassifier( |
| 61 | + model, |
| 62 | + class_names, |
| 63 | + model_name="fraud_detector", |
| 64 | + interface_options=interface_options |
| 65 | + ) |
| 66 | +
|
| 67 | +
|
| 68 | +To run the model, we perform the same steps as we would with a regular PyTorch model. In this example we use a dummy input. |
| 69 | +This gives us a list of facts that can then be added to PyReason. |
| 70 | + |
| 71 | +.. code-block:: python |
| 72 | +
|
| 73 | + transaction_features = torch.rand(1, 5) |
| 74 | +
|
| 75 | + # Get the prediction from the model |
| 76 | + logits, probabilities, classifier_facts = fraud_detector(transaction_features) |
| 77 | +
|
| 78 | +We now add the facts to PyReason as normal. |
| 79 | + |
| 80 | +.. code-block:: python |
| 81 | +
|
| 82 | + # Add the classifier-generated facts. |
| 83 | + for fact in classifier_facts: |
| 84 | + pr.add_fact(fact) |
| 85 | +
|
| 86 | +
|
| 87 | +
|
| 88 | +Next, we define a knowledge graph that contains information about accounts and its relationships. we also define some context |
| 89 | +about the transaction and rules that we want to reason over with the classifier predictions. |
| 90 | + |
| 91 | +.. code-block:: python |
| 92 | + # Create a networkx graph representing a network of accounts. |
| 93 | + G = nx.DiGraph() |
| 94 | + # Add account nodes. |
| 95 | + G.add_node("AccountA", account=1) |
| 96 | + G.add_node("AccountB", account=1) |
| 97 | + G.add_node("AccountC", account=1) |
| 98 | +
|
| 99 | + # Add edges with an attribute "associated". |
| 100 | + G.add_edge("AccountA", "AccountB", associated=1) |
| 101 | + G.add_edge("AccountB", "AccountC", associated=1) |
| 102 | + pr.load_graph(G) |
| 103 | +
|
| 104 | + # Add additional contextual information: |
| 105 | + # 1. A fact indicating the transaction comes from a suspicious location. This could come from a separate fraud detection system. |
| 106 | + pr.add_fact(pr.Fact("suspicious_location(AccountA)", "transaction_fact")) |
| 107 | +
|
| 108 | + # Define a rule: if the fraud detector flags a transaction as fraud and the transaction info is suspicious, |
| 109 | + # then mark the associated account (AccountA) as requiring investigation. |
| 110 | + pr.add_rule(pr.Rule("requires_investigation(acc) <- account(acc), fraud_detector(fraud), suspicious_location(acc)", "investigation_rule")) |
| 111 | +
|
| 112 | + # Define a propagation rule: |
| 113 | + # If an account requires investigation and is connected (via the "associated" relationship) to another account, |
| 114 | + # then the connected account is also flagged for investigation. |
| 115 | + pr.add_rule(pr.Rule("requires_investigation(y) <- requires_investigation(x), associated(x,y)", "propagation_rule")) |
| 116 | +
|
| 117 | +
|
| 118 | +Finally, we run the reasoning process and print the output. |
| 119 | + |
| 120 | +.. code-block:: python |
| 121 | +
|
| 122 | + # Run the reasoning engine to allow the investigation flag to propagate through the network. |
| 123 | + pr.settings.allow_ground_rules = True # The ground rules allow us to use the classifier prediction facts |
| 124 | + pr.settings.atom_trace = True |
| 125 | + interpretation = pr.reason() |
| 126 | +
|
| 127 | + trace = pr.get_rule_trace(interpretation) |
| 128 | + print(f"RULE TRACE: \n\n{trace[0]}\n") |
| 129 | +
|
| 130 | +
|
| 131 | +
|
| 132 | +This simple example demonstrates the integration of a machine learning model with PyReason. In practice more complex models |
| 133 | +can be used, along with larger and more complex knowledge graphs. |
0 commit comments