diff --git a/difflogic/difflogic.py b/difflogic/difflogic.py index bd2310c..a722de7 100644 --- a/difflogic/difflogic.py +++ b/difflogic/difflogic.py @@ -1,10 +1,21 @@ import torch -import difflogic_cuda +from pallas import autograd as pallas_autograd import numpy as np +import jax +import jax.numpy as jnp +from jax.nn import softmax, one_hot from .functional import bin_op_s, get_unique_connections, GradFactor from .packbitstensor import PackBitsTensor +from pallas import compile as pallas_compile + +@pallas_compile(backend="cuda") +def logic_layer_kernel(x, a, b, w, y): + i = pallas_autograd.program_id(0) + y[i] = x[i] + + ######################################################################################################################## @@ -53,30 +64,14 @@ def __init__( assert self.connections in ['random', 'unique'], self.connections self.indices = self.get_connections(self.connections, device) - if self.implementation == 'cuda': - """ - Defining additional indices for improving the efficiency of the backward of the CUDA implementation. - """ - given_x_indices_of_y = [[] for _ in range(in_dim)] - indices_0_np = self.indices[0].cpu().numpy() - indices_1_np = self.indices[1].cpu().numpy() - for y in range(out_dim): - given_x_indices_of_y[indices_0_np[y]].append(y) - given_x_indices_of_y[indices_1_np[y]].append(y) - self.given_x_indices_of_y_start = torch.tensor( - np.array([0] + [len(g) for g in given_x_indices_of_y]).cumsum(), device=device, dtype=torch.int64) - self.given_x_indices_of_y = torch.tensor( - [item for sublist in given_x_indices_of_y for item in sublist], dtype=torch.int64, device=device) - self.num_neurons = out_dim self.num_weights = out_dim - def forward(self, x): + def forward(self, x, training: bool): if isinstance(x, PackBitsTensor): - assert not self.training, 'PackBitsTensor is not supported for the differentiable training mode.' - assert self.device == 'cuda', 'PackBitsTensor is only supported for CUDA, not for {}. 
' \ - 'If you want fast inference on CPU, please use CompiledDiffLogicModel.' \ - ''.format(self.device) + assert not training, 'PackBitsTensor is not supported for the differentiable training mode.' + assert self.device == 'cuda', 'PackBitsTensor is only supported for CUDA, not for {}. '.format(self.device) + \ + 'If you want fast inference on CPU, please use CompiledDiffLogicModel.' else: if self.grad_factor != 1.: @@ -85,68 +80,41 @@ def forward(self, x): if self.implementation == 'cuda': if isinstance(x, PackBitsTensor): return self.forward_cuda_eval(x) - return self.forward_cuda(x) + return self.forward_cuda(x, training) elif self.implementation == 'python': - return self.forward_python(x) + return self.forward_python(x, training) else: raise ValueError(self.implementation) - def forward_python(self, x): - assert x.shape[-1] == self.in_dim, (x[0].shape[-1], self.in_dim) - - if self.indices[0].dtype == torch.int64 or self.indices[1].dtype == torch.int64: - print(self.indices[0].dtype, self.indices[1].dtype) - self.indices = self.indices[0].long(), self.indices[1].long() - print(self.indices[0].dtype, self.indices[1].dtype) + def forward_python(self, x, training: bool): + assert x.shape[-1] == self.in_dim, (x.shape[-1], self.in_dim) a, b = x[..., self.indices[0]], x[..., self.indices[1]] - if self.training: - x = bin_op_s(a, b, torch.nn.functional.softmax(self.weights, dim=-1)) + weights = jnp.array(self.weights) # Convert to JAX array + if training: + x = bin_op_s(a, b, softmax(weights, axis=-1)) else: - weights = torch.nn.functional.one_hot(self.weights.argmax(-1), 16).to(torch.float32) + weights = one_hot(jnp.argmax(weights, axis=-1), 16).astype(jnp.float32) x = bin_op_s(a, b, weights) return x - def forward_cuda(self, x): - if self.training: - assert x.device.type == 'cuda', x.device - assert x.ndim == 2, x.ndim + def forward_cuda(self, x, training: bool): + x = jnp.array(x) + a = jnp.array(self.indices[0]) + b = jnp.array(self.indices[1]) + w = 
jnp.array(self.weights) + y = jnp.zeros((x.shape[0], self.out_dim), dtype=x.dtype) - x = x.transpose(0, 1) - x = x.contiguous() + grid_dim = (x.shape[0],) + block_dim = (min(x.shape[0], 1024),) - assert x.shape[0] == self.in_dim, (x.shape, self.in_dim) + logic_layer_kernel[grid_dim, block_dim](x[:,0], a, b, w, y) - a, b = self.indices - - if self.training: - w = torch.nn.functional.softmax(self.weights, dim=-1).to(x.dtype) - return LogicLayerCudaFunction.apply( - x, a, b, w, self.given_x_indices_of_y_start, self.given_x_indices_of_y - ).transpose(0, 1) - else: - w = torch.nn.functional.one_hot(self.weights.argmax(-1), 16).to(x.dtype) - with torch.no_grad(): - return LogicLayerCudaFunction.apply( - x, a, b, w, self.given_x_indices_of_y_start, self.given_x_indices_of_y - ).transpose(0, 1) + return y def forward_cuda_eval(self, x: PackBitsTensor): - """ - WARNING: this is an in-place operation. - - :param x: - :return: - """ - assert not self.training - assert isinstance(x, PackBitsTensor) - assert x.t.shape[0] == self.in_dim, (x.t.shape, self.in_dim) - - a, b = self.indices - w = self.weights.argmax(-1).to(torch.uint8) - x.t = difflogic_cuda.eval(x.t, a, b, w) - - return x + raise NotImplementedError("`forward_cuda_eval` is not yet implemented for the JAX version. " + "PackBitsTensor is currently not supported in JAX.") def extra_repr(self): return '{}, {}, {}'.format(self.in_dim, self.out_dim, 'train' if self.training else 'eval') @@ -172,53 +140,22 @@ def get_connections(self, connections, device='cuda'): ######################################################################################################################## -class GroupSum(torch.nn.Module): +class GroupSum: """ The GroupSum module. """ def __init__(self, k: int, tau: float = 1., device='cuda'): - """ - - :param k: number of intended real valued outputs, e.g., number of classes - :param tau: the (softmax) temperature tau. The summed outputs are divided by tau. 
- :param device: - """ - super().__init__() self.k = k self.tau = tau self.device = device def forward(self, x): if isinstance(x, PackBitsTensor): - return x.group_sum(self.k) + raise NotImplementedError("PackBitsTensor is not yet supported in JAX.") assert x.shape[-1] % self.k == 0, (x.shape, self.k) - return x.reshape(*x.shape[:-1], self.k, x.shape[-1] // self.k).sum(-1) / self.tau - - def extra_repr(self): - return 'k={}, tau={}'.format(self.k, self.tau) - + return x.reshape(*x.shape[:-1], self.k, x.shape[-1] // self.k).sum(axis=-1) / self.tau -######################################################################################################################## - - -class LogicLayerCudaFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, x, a, b, w, given_x_indices_of_y_start, given_x_indices_of_y): - ctx.save_for_backward(x, a, b, w, given_x_indices_of_y_start, given_x_indices_of_y) - return difflogic_cuda.forward(x, a, b, w) - - @staticmethod - def backward(ctx, grad_y): - x, a, b, w, given_x_indices_of_y_start, given_x_indices_of_y = ctx.saved_tensors - grad_y = grad_y.contiguous() + def __repr__(self): + return f'GroupSum(k={self.k}, tau={self.tau})' - grad_w = grad_x = None - if ctx.needs_input_grad[0]: - grad_x = difflogic_cuda.backward_x(x, a, b, w, grad_y, given_x_indices_of_y_start, given_x_indices_of_y) - if ctx.needs_input_grad[3]: - grad_w = difflogic_cuda.backward_w(x, a, b, grad_y) - return grad_x, None, None, grad_w, None, None, None - - -######################################################################################################################## diff --git a/difflogic/functional.py b/difflogic/functional.py index 1fd181c..d842270 100644 --- a/difflogic/functional.py +++ b/difflogic/functional.py @@ -1,4 +1,4 @@ -import torch +import jax.numpy as jnp import numpy as np BITS_TO_NP_DTYPE = {8: np.int8, 16: np.int16, 32: np.int32, 64: np.int64} @@ -29,7 +29,7 @@ def bin_op(a, b, i): assert a[1].shape == b[1].shape, 
(a[1].shape, b[1].shape)
 
     if i == 0:
-        return torch.zeros_like(a)
+        return jnp.zeros_like(a)
     elif i == 1:
         return a * b
     elif i == 2:
@@ -59,11 +59,11 @@ def bin_op(a, b, i):
     elif i == 14:
         return 1 - a * b
     elif i == 15:
-        return torch.ones_like(a)
+        return jnp.ones_like(a)
 
 
 def bin_op_s(a, b, i_s):
-    r = torch.zeros_like(a)
+    r = jnp.zeros_like(a)
     for i in range(16):
         u = bin_op(a, b, i)
         r = r + i_s[..., i] * u
diff --git a/difflogic/tests/test_jax.py b/difflogic/tests/test_jax.py
new file mode 100644
index 0000000..d56bb74
--- /dev/null
+++ b/difflogic/tests/test_jax.py
@@ -0,0 +1,26 @@
+import jax
+import jax.numpy as jnp
+from jax.nn import softmax, one_hot
+from difflogic import LogicLayer, GroupSum
+from difflogic.functional import bin_op_s
+import numpy as np
+
+def test_logic_layer_forward():
+    in_dim = 4
+    out_dim = 2
+    layer = LogicLayer(in_dim=in_dim, out_dim=out_dim, implementation="python", connections="unique")
+    x = jnp.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
+    layer.weights = np.random.randn(out_dim, 16)
+    output = layer(x, training=True)
+    assert output.shape == (2, 2)
+
+def test_group_sum_forward():
+    k = 2
+    group_sum = GroupSum(k=k)
+    x = jnp.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])
+    output = group_sum.forward(x)
+    np.testing.assert_allclose(np.asarray(output), np.array([[3., 7.], [11., 15.]]))
+
+
+test_logic_layer_forward()
+test_group_sum_forward()
diff --git a/experiments/main.py b/experiments/main.py
index b55bde2..6b82209 100644
--- a/experiments/main.py
+++ b/experiments/main.py
@@ -4,8 +4,191 @@
 import os
 
 import numpy as np
-import torch
-import torchvision
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+import optax
+from tqdm import tqdm
+
+from results_json import ResultsJSON
+
+from difflogic import LogicLayer, GroupSum
+
+
+def load_dataset(args):
+    # Placeholder for data loading - replace with actual data loading later
+    if args.dataset == "mnist":
+        num_classes = 10
+        key = jax.random.PRNGKey(0)
+        train_data = 
jax.random.normal(key, (50000, 784))
+        train_labels = jax.random.randint(key, (50000,), 0, num_classes)
+        test_data = jax.random.normal(key, (10000, 784))
+        test_labels = jax.random.randint(key, (10000,), 0, num_classes)
+
+        return (train_data, train_labels), (test_data, test_labels), None
+    else:
+        raise NotImplementedError(f"Dataset {args.dataset} not yet supported for JAX.")
+
+
+
+def input_dim_of_dataset(dataset):
+    return {
+        'mnist': 784,
+    }[dataset]
+
+
+def num_classes_of_dataset(dataset):
+    return {
+        'mnist': 10,
+    }[dataset]
+
+
+
+
+class Model(nn.Module):
+    in_dim: int
+    out_dim: int
+    num_layers: int
+    num_classes: int
+    tau: float
+    grad_factor: float
+    connections: str
+
+    @nn.compact
+    def __call__(self, x, training: bool):
+        llkw = dict(grad_factor=self.grad_factor, connections=self.connections, implementation="python")
+
+        x = x.reshape((x.shape[0], -1))  # Flatten the input
+
+        for _ in range(self.num_layers):
+            x = LogicLayer(in_dim=self.in_dim if _ == 0 else self.out_dim, out_dim=self.out_dim, **llkw)(x, training=training)
+        x = GroupSum(k=self.num_classes, tau=self.tau).forward(x)
+        return x
+
+
+def get_model(args):
+    key = jax.random.PRNGKey(0)
+
+    in_dim = input_dim_of_dataset(args.dataset)
+    class_count = num_classes_of_dataset(args.dataset)
+
+    model = Model(in_dim=in_dim, out_dim=args.num_neurons, num_layers=args.num_layers, num_classes=class_count, tau=args.tau, grad_factor=args.grad_factor, connections=args.connections)
+
+    params = model.init(key, jnp.ones((1, in_dim)), training=True)
+
+    # Placeholder loss function
+    def loss_fn(params, x, y):
+        logits = model.apply(params, x, training=True)
+        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
+
+    optimizer = optax.adam(learning_rate=args.learning_rate)
+
+    return model, params, loss_fn, optimizer
+
+
+# NOTE(review): jax.jit removed -- loss_fn is a Python callable, not a valid traced argument; re-enable via functools.partial(jax.jit, static_argnums=(4,)).
+def train_step(params, opt_state, x, y, loss_fn):
+    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
+    updates, opt_state = optimizer.update(grads, 
opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss + +def eval(params, model, x, y): + logits = model.apply(params, x, training=False) + accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y) + return accuracy + + +def packbits_eval(model, loader): + raise NotImplementedError("PackBitsTensor is not yet supported in JAX.") + + +if __name__ == '__main__': + + #################################################################################################################### + + parser = argparse.ArgumentParser(description='Train logic gate network on the various datasets.') + + parser.add_argument('-eid', '--experiment_id', type=int, default=None) + + parser.add_argument('--dataset', type=str, choices=[ + 'mnist', + ], required=True, help='the dataset to use') + parser.add_argument('--tau', '-t', type=float, default=10, help='the softmax temperature tau') + parser.add_argument('--seed', '-s', type=int, default=0, help='seed (default: 0)') + parser.add_argument('--batch-size', '-bs', type=int, default=128, help='batch size (default: 128)') + parser.add_argument('--learning-rate', '-lr', type=float, default=0.01, help='learning rate (default: 0.01)') + + + parser.add_argument('--implementation', type=str, default='python', choices=['cuda', 'python'], + help='`cuda` is the fast CUDA implementation and `python` is simpler but much slower ' + 'implementation intended for helping with the understanding.') + + + parser.add_argument('--num-iterations', '-ni', type=int, default=1000, help='Number of iterations (default: 1000)') + parser.add_argument('--eval-freq', '-ef', type=int, default=200, help='Evaluation frequency (default: 200)') + + + + parser.add_argument('--connections', type=str, default='unique', choices=['random', 'unique']) + parser.add_argument('--architecture', '-a', type=str, default='randomly_connected') + parser.add_argument('--num_neurons', '-k', type=int) + parser.add_argument('--num_layers', '-l', type=int) 
+ + parser.add_argument('--grad-factor', type=float, default=1.) + + args = parser.parse_args() + + #################################################################################################################### + + print(vars(args)) + + assert args.num_iterations % args.eval_freq == 0, ( + f'iteration count ({args.num_iterations}) has to be divisible by evaluation frequency ({args.eval_freq})' + ) + + + key = jax.random.PRNGKey(args.seed) + (train_data, train_labels), (test_data, test_labels), _ = load_dataset(args) + model, params, loss_fn, optimizer = get_model(args) + + opt_state = optimizer.init(params) + + + best_acc = 0 + + for i in tqdm(range(args.num_iterations), desc='iteration', total=args.num_iterations): + x = train_data[i % len(train_data)].reshape((1, -1)) + y = train_labels[i % len(train_data)].reshape((1,)) + + params, opt_state, loss = train_step(params, opt_state, x, y, loss_fn) + + + if (i+1) % args.eval_freq == 0: + + train_accuracy = eval(params, model, train_data, train_labels) + test_accuracy = eval(params, model, test_data, test_labels) + + r = { + 'train_acc': train_accuracy, + 'test_acc': test_accuracy, + } + + print(r) + + if test_accuracy > best_acc: + best_acc = test_accuracy + print('IS THE BEST UNTIL NOW.') + +import math +import random +import os + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +import optax from tqdm import tqdm from results_json import ResultsJSON @@ -14,68 +197,23 @@ import uci_datasets from difflogic import LogicLayer, GroupSum, PackBitsTensor, CompiledLogicNet -torch.set_num_threads(1) - -BITS_TO_TORCH_FLOATING_POINT_TYPE = { - 16: torch.float16, - 32: torch.float32, - 64: torch.float64 -} - def load_dataset(args): - validation_loader = None - if args.dataset == 'adult': - train_set = uci_datasets.AdultDataset('./data-uci', split='train', download=True, with_val=False) - test_set = uci_datasets.AdultDataset('./data-uci', split='test', with_val=False) - train_loader = 
torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size=int(1e6), shuffle=False) - elif args.dataset == 'breast_cancer': - train_set = uci_datasets.BreastCancerDataset('./data-uci', split='train', download=True, with_val=False) - test_set = uci_datasets.BreastCancerDataset('./data-uci', split='test', with_val=False) - train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size=int(1e6), shuffle=False) - elif args.dataset.startswith('monk'): - style = int(args.dataset[4]) - train_set = uci_datasets.MONKsDataset('./data-uci', style, split='train', download=True, with_val=False) - test_set = uci_datasets.MONKsDataset('./data-uci', style, split='test', with_val=False) - train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size=int(1e6), shuffle=False) - elif args.dataset in ['mnist', 'mnist20x20']: - train_set = mnist_dataset.MNIST('./data-mnist', train=True, download=True, remove_border=args.dataset == 'mnist20x20') - test_set = mnist_dataset.MNIST('./data-mnist', train=False, remove_border=args.dataset == 'mnist20x20') - - train_set_size = math.ceil((1 - args.valid_set_size) * len(train_set)) - valid_set_size = len(train_set) - train_set_size - train_set, validation_set = torch.utils.data.random_split(train_set, [train_set_size, valid_set_size]) - - train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=4) - validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=True) - elif 'cifar-10' in 
args.dataset: - transform = { - 'cifar-10-3-thresholds': lambda x: torch.cat([(x > (i + 1) / 4).float() for i in range(3)], dim=0), - 'cifar-10-31-thresholds': lambda x: torch.cat([(x > (i + 1) / 32).float() for i in range(31)], dim=0), - }[args.dataset] - transforms = torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Lambda(transform), - ]) - train_set = torchvision.datasets.CIFAR10('./data-cifar', train=True, download=True, transform=transforms) - test_set = torchvision.datasets.CIFAR10('./data-cifar', train=False, transform=transforms) - - train_set_size = math.ceil((1 - args.valid_set_size) * len(train_set)) - valid_set_size = len(train_set) - train_set_size - train_set, validation_set = torch.utils.data.random_split(train_set, [train_set_size, valid_set_size]) - - train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=4) - validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=True) - test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, pin_memory=True, drop_last=True) - + # Placeholder for data loading - replace with actual data loading later + if args.dataset == "mnist": + num_classes = 10 + key = jax.random.PRNGKey(0) + train_data = jax.random.normal(key, (50000, 784)) + train_labels = jax.random.randint(key, (50000,), 0, num_classes) + test_data = jax.random.normal(key, (10000, 784)) + test_labels = jax.random.randint(key, (10000,), 0, num_classes) + + return (train_data, train_labels), (test_data, test_labels), None else: - raise NotImplementedError(f'The data set {args.dataset} is not supported!') + raise NotImplementedError(f"Dataset {args.dataset} not yet supported for JAX.") + + - return train_loader, validation_loader, test_loader def load_n(loader, n): @@ -116,59 +254,52 @@ def 
num_classes_of_dataset(dataset): }[dataset] -def get_model(args): - llkw = dict(grad_factor=args.grad_factor, connections=args.connections) - - in_dim = input_dim_of_dataset(args.dataset) - class_count = num_classes_of_dataset(args.dataset) - - logic_layers = [] +import jax +import jax.numpy as jnp +from flax import linen as nn +import optax +from difflogic import LogicLayer, GroupSum - arch = args.architecture - k = args.num_neurons - l = args.num_layers +class Model(nn.Module): + in_dim: int + out_dim: int + num_layers: int + num_classes: int + tau: float + grad_factor: float + connections: str - #################################################################################################################### + @nn.compact + def __call__(self, x, training: bool): + llkw = dict(grad_factor=self.grad_factor, connections=self.connections, implementation="python") - if arch == 'randomly_connected': - logic_layers.append(torch.nn.Flatten()) - logic_layers.append(LogicLayer(in_dim=in_dim, out_dim=k, **llkw)) - for _ in range(l - 1): - logic_layers.append(LogicLayer(in_dim=k, out_dim=k, **llkw)) + x = x.reshape((x.shape[0], -1)) # Flatten the input - model = torch.nn.Sequential( - *logic_layers, - GroupSum(class_count, args.tau) - ) + for _ in range(self.num_layers): + x = LogicLayer(in_dim=self.in_dim if _ == 0 else self.out_dim, out_dim=self.out_dim, **llkw)(x, training=training) + x = GroupSum(k=self.num_classes, tau=self.tau)(x) + return x - #################################################################################################################### - else: - raise NotImplementedError(arch) +def get_model(args): + key = jax.random.PRNGKey(0) - #################################################################################################################### + in_dim = input_dim_of_dataset(args.dataset) + class_count = num_classes_of_dataset(args.dataset) - total_num_neurons = sum(map(lambda x: x.num_neurons, logic_layers[1:-1])) - 
print(f'total_num_neurons={total_num_neurons}') - total_num_weights = sum(map(lambda x: x.num_weights, logic_layers[1:-1])) - print(f'total_num_weights={total_num_weights}') - if args.experiment_id is not None: - results.store_results({ - 'total_num_neurons': total_num_neurons, - 'total_num_weights': total_num_weights, - }) + model = Model(in_dim=in_dim, out_dim=args.num_neurons, num_layers=args.num_layers, num_classes=class_count, tau=args.tau, grad_factor=args.grad_factor, connections=args.connections) - model = model.to('cuda') + params = model.init(key, jnp.ones((1, in_dim)), training=True) - print(model) - if args.experiment_id is not None: - results.store_results({'model_str': str(model)}) + # Placeholder loss function + def loss_fn(params, x, y): + logits = model.apply(params, x, training=True) + return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean() - loss_fn = torch.nn.CrossEntropyLoss() + optimizer = optax.adam(learning_rate=args.learning_rate) - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + return model, params, loss_fn, optimizer - return model, loss_fn, optimizer def train(model, x, y, loss_fn, optimizer): diff --git a/experiments/requirements.txt b/experiments/requirements.txt index e8a37ab..25373ff 100644 --- a/experiments/requirements.txt +++ b/experiments/requirements.txt @@ -3,3 +3,4 @@ scikit-learn torch torchvision difflogic +pallas diff --git a/setup.py b/setup.py index 89804c3..42e8f98 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ cmdclass={'build_ext': BuildExtension}, python_requires='>=3.6', install_requires=[ + 'pallas', 'torch>=1.6.0', 'numpy', ],