|
| 1 | +from typing import List |
| 2 | + |
| 3 | +import torch.nn |
| 4 | +import torch.nn.functional as F |
| 5 | + |
| 6 | +from pyreason.scripts.facts.fact import Fact |
| 7 | +from pyreason.scripts.learning.utils.model_interface import ModelInterfaceOptions |
| 8 | + |
| 9 | + |
| 10 | +class LogicIntegratedClassifier(torch.nn.Module): |
| 11 | + """ |
| 12 | + Class to integrate a PyTorch model with PyReason. The output of the model is returned to the |
| 13 | + user in the form of PyReason facts. The user can then add these facts to the logic program and reason using them. |
| 14 | + """ |
| 15 | + def __init__(self, model, class_names: List[str], model_name: str = 'classifier', interface_modes: ModelInterfaceOptions = None): |
| 16 | + """ |
| 17 | + :param model: |
| 18 | + :param class_names: |
| 19 | + """ |
| 20 | + super(LogicIntegratedClassifier, self).__init__() |
| 21 | + self.model = model |
| 22 | + self.class_names = class_names |
| 23 | + self.model_name = model_name |
| 24 | + self.interface_modes = interface_modes |
| 25 | + |
| 26 | + def get_class_facts(self, t1: int, t2: int) -> List[Fact]: |
| 27 | + """ |
| 28 | + Return PyReason facts to create nodes for each class. Each class node will have bounds `[1,1]` with the |
| 29 | + predicate corresponding to the model name. |
| 30 | + :param t1: Start time for the facts |
| 31 | + :param t2: End time for the facts |
| 32 | + :return: List of PyReason facts |
| 33 | + """ |
| 34 | + facts = [] |
| 35 | + for c in self.class_names: |
| 36 | + fact = Fact(f'{self.model_name}({c})', name=f'{self.model_name}-{c}-fact', start_time=t1, end_time=t2) |
| 37 | + facts.append(fact) |
| 38 | + return facts |
| 39 | + |
| 40 | + def forward(self, x, t1: int, t2: int): |
| 41 | + """ |
| 42 | + Forward pass of the model |
| 43 | + :param x: Input tensor |
| 44 | + :param t1: Start time for the facts |
| 45 | + :param t2: End time for the facts |
| 46 | + :return: Output tensor |
| 47 | + """ |
| 48 | + output = self.model(x) |
| 49 | + |
| 50 | + # Convert logits to probabilities assuming a multi-class classification. |
| 51 | + probabilities = F.softmax(output, dim=1).squeeze() |
| 52 | + opts = self.integration_options |
| 53 | + |
| 54 | + # Prepare threshold tensor. |
| 55 | + threshold = torch.tensor(opts.threshold, dtype=probabilities.dtype, device=probabilities.device) |
| 56 | + condition = probabilities > threshold |
| 57 | + |
| 58 | + if opts.snap_value is not None: |
| 59 | + snap_value = torch.tensor(opts.snap_value, dtype=probabilities.dtype, device=probabilities.device) |
| 60 | + # For values that pass the threshold: |
| 61 | + lower_val = snap_value if opts.set_lower_bound else torch.tensor(0.0, dtype=probabilities.dtype, |
| 62 | + device=probabilities.device) |
| 63 | + upper_val = snap_value if opts.set_upper_bound else torch.tensor(1.0, dtype=probabilities.dtype, |
| 64 | + device=probabilities.device) |
| 65 | + else: |
| 66 | + # If no snap_value is provided, keep original probabilities for those passing threshold. |
| 67 | + lower_val = probabilities |
| 68 | + upper_val = probabilities |
| 69 | + |
| 70 | + # For probabilities that pass the threshold, apply the above; else, bounds are fixed to [0,1]. |
| 71 | + lower_bounds = torch.where(condition, lower_val, torch.zeros_like(probabilities)) |
| 72 | + upper_bounds = torch.where(condition, upper_val, torch.ones_like(probabilities)) |
| 73 | + |
| 74 | + # Convert bounds to Python floats for fact creation. |
| 75 | + bounds_list = [] |
| 76 | + for i in range(len(self.class_names)): |
| 77 | + lower = lower_bounds[i].item() |
| 78 | + upper = upper_bounds[i].item() |
| 79 | + bounds_list.append([lower, upper]) |
| 80 | + |
| 81 | + # Define time bounds for the facts. |
| 82 | + facts = [] |
| 83 | + for class_name, bounds in zip(self.class_names, bounds_list): |
| 84 | + lower, upper = bounds |
| 85 | + fact_str = f'{self.model_name}({class_name}) : [{lower:.3f}, {upper:.3f}])' |
| 86 | + fact = Fact(fact_str, name=f'{self.model_name}-{class_name}-fact', start_time=t1, end_time=t2) |
| 87 | + facts.append(fact) |
| 88 | + return output, facts |
| 89 | + |
0 commit comments