From 0006bed6426c21dbd871d64fb6055310ca91ef1e Mon Sep 17 00:00:00 2001 From: nazar Date: Mon, 25 Mar 2019 22:04:03 +0300 Subject: [PATCH 1/3] libtorch C++ jni --- README.md | 446 +- build.sbt | 82 +- project/Dependencies.scala | 6 - project/JniBuildPlugin.scala | 48 + project/JniGeneratorPlugin.scala | 125 + project/build.properties | 1 - project/plugins.sbt | 25 +- src/main/java/generate/Builder.java | 691 ++ src/main/java/generate/ClassScanner.java | 135 + src/main/java/generate/UserClassLoader.java | 74 + .../org/bytedeco/javacpp/presets/torch.java | 13 + src/main/java/torch_java/api/Functions.java | 52 + src/main/java/torch_java/api/Tensor.java | 606 ++ src/main/java/torch_java/api/nn/Module.java | 25 + .../java/torch_java/examples/FourierNet.java | 57 + src/main/resources/logback.xml | 35 - src/main/scala/scorch/autograd/Function.scala | 358 - src/main/scala/scorch/autograd/Variable.scala | 85 - .../data/loader/Cifar10DataLoader.scala | 171 - .../scala/scorch/data/loader/DataLoader.scala | 25 - .../scorch/data/loader/MnistDataLoader.scala | 62 - src/main/scala/scorch/nn/BatchNorm.scala | 127 - src/main/scala/scorch/nn/Dropout.scala | 66 - src/main/scala/scorch/nn/Linear.scala | 24 - src/main/scala/scorch/nn/Module.scala | 85 - src/main/scala/scorch/nn/ParallelModule.scala | 53 - src/main/scala/scorch/nn/cnn/Conv2d.scala | 307 - src/main/scala/scorch/nn/cnn/MaxPool2d.scala | 186 - src/main/scala/scorch/nn/rnn/GruCell.scala | 66 - src/main/scala/scorch/nn/rnn/LstmCell.scala | 89 - src/main/scala/scorch/nn/rnn/RnnBase.scala | 108 - src/main/scala/scorch/nn/rnn/RnnCell.scala | 52 - .../scala/scorch/nn/rnn/RnnCellBase.scala | 32 - src/main/scala/scorch/optim/Adam.scala | 36 - src/main/scala/scorch/optim/DCASGDa.scala | 40 - src/main/scala/scorch/optim/Nesterov.scala | 19 - src/main/scala/scorch/optim/Optimizer.scala | 10 - src/main/scala/scorch/optim/SGD.scala | 11 - src/main/scala/scorch/package.scala | 50 - .../scala/scorch/sandbox/MnistWrangler.scala | 87 - .../scala/scorch/sandbox/ReadmeConvNet.scala | 85 - .../scala/scorch/sandbox/cnn/LeNet5.scala | 214 - .../sandbox/rnn/DinosaurIslandCharRnn.scala | 509 - .../scorch/sandbox/rnn/LanguageModel.scala | 150 - src/main/scala/scorch/sandbox/rnn/Rnn.scala | 80 - .../scorch/sandbox/rnn/TemporalAffine.scala | 35 - .../scorch/sandbox/rnn/TemporalSoftmax.scala | 61 - .../scorch/sandbox/rnn/WordEmbedding.scala | 70 - src/main/scala/torch_scala/Torch.scala | 38 + .../scala/torch_scala/api/aten/ArrayRef.scala | 123 + .../scala/torch_scala/api/aten/Device.scala | 32 + .../scala/torch_scala/api/aten/Indexer.scala | 462 + .../api/aten/PrimitivePointer.scala | 27 + .../scala/torch_scala/api/aten/Scalar.scala | 37 + .../scala/torch_scala/api/aten/Shape.scala | 382 + .../scala/torch_scala/api/aten/Tensor.scala | 427 + .../torch_scala/api/aten/TensorOptions.scala | 67 + .../torch_scala/api/aten/TensorType.scala | 10 + .../torch_scala/api/aten/TensorVector.scala | 53 + .../api/aten/functions/Basic.scala | 44 + .../api/aten/functions/Functions.scala | 81 + .../torch_scala/api/aten/functions/Math.scala | 86 + .../api/aten/functions/MathBackward.scala | 18 + .../scala/torch_scala/api/nn/Module.scala | 21 + src/main/scala/torch_scala/api/package.scala | 120 + .../torch_scala/api/types/DataType.scala | 171 + .../scala/torch_scala/api/types/types.scala | 231 + .../scala/torch_scala/apps/FourierNet.scala | 75 + .../scala/torch_scala/autograd/Function.scala | 408 + .../torch_scala/autograd/MathVariable.scala | 63 + .../scala/torch_scala/autograd/Variable.scala | 
56 + .../torch_scala/examples/FourierNet.scala | 27 + .../exceptions/TorchExceptions.scala | 148 + .../native_generator/Generate.scala | 25 + src/main/scala/torch_scala/nn/Linear.scala | 28 + src/main/scala/torch_scala/nn/Module.scala | 88 + src/main/scala/torch_scala/optim/Adam.scala | 43 + .../optim/DCASGD.scala | 23 +- .../scala/torch_scala/optim/DCASGDa.scala | 39 + .../scala/torch_scala/optim/Nesterov.scala | 18 + .../scala/torch_scala/optim/Optimizer.scala | 10 + src/main/scala/torch_scala/optim/SGD.scala | 12 + src/native/CMakeLists.txt | 42 + .../detect_cuda_compute_capabilities.cpp | 15 + src/native/detect_cuda_version.cc | 6 + src/native/helper.h | 70 + src/native/java_torch_lib.cpp | 6089 +++++++++++ src/native/jnijavacpp.cpp | 2113 ++++ src/native/libjava_torch_lib0.so | Bin 0 -> 963424 bytes src/native/models/FourierNet.cpp | 8 + src/native/models/FourierNet.h | 103 + src/test/resources/dinos.txt | 1536 --- src/test/resources/names/Arabic.txt | 2000 ---- src/test/resources/names/Chinese.txt | 268 - src/test/resources/names/Czech.txt | 519 - src/test/resources/names/Dutch.txt | 297 - src/test/resources/names/English.txt | 3668 ------- src/test/resources/names/French.txt | 277 - src/test/resources/names/German.txt | 724 -- src/test/resources/names/Greek.txt | 203 - src/test/resources/names/Irish.txt | 232 - src/test/resources/names/Italian.txt | 709 -- src/test/resources/names/Japanese.txt | 991 -- src/test/resources/names/Korean.txt | 94 - src/test/resources/names/Polish.txt | 139 - src/test/resources/names/Portuguese.txt | 74 - src/test/resources/names/Russian.txt | 9408 ----------------- src/test/resources/names/Scottish.txt | 100 - src/test/resources/names/Spanish.txt | 298 - src/test/resources/names/Vietnamese.txt | 73 - src/test/resources/sonnets-cleaned.txt | 2155 ---- src/test/scala/scorch/TestUtil.scala | 155 - .../scala/scorch/autograd/AutoGradSpec.scala | 154 - .../autograd/FunctionGradientSpec.scala | 196 - .../scala/scorch/autograd/FunctionSpec.scala | 206 - .../scorch/data/loader/DataLoaderSpec.scala | 100 - src/test/scala/scorch/nn/BatchNormSpec.scala | 292 - src/test/scala/scorch/nn/ModuleSpec.scala | 309 - .../scala/scorch/nn/ParallelModuleSpec.scala | 95 - src/test/scala/scorch/nn/cnn/Conv2dSpec.scala | 148 - .../scala/scorch/nn/cnn/MaxPool2dSpec.scala | 90 - .../scorch/sandbox/rnn/CharRnnSpec.scala | 202 - .../scala/scorch/sandbox/rnn/RnnSpec.scala | 186 - .../sandbox/rnn/TemporalAffineSpec.scala | 65 - .../sandbox/rnn/TemporalSoftmaxSpec.scala | 60 - .../sandbox/rnn/WordEmbeddingSpec.scala | 159 - src/test/scala/torch_scala/TestUtil.scala | 145 + .../scala/torch_scala/aten/TensorSpec.scala | 65 + .../torch_scala/autograd/AutoGradSpec.scala | 157 + .../autograd/FunctionGradientSpec.scala | 208 + .../torch_scala/autograd/FunctionSpec.scala | 51 + .../scala/torch_scala/nn/ModuleSpec.scala | 206 + 132 files changed, 14646 insertions(+), 30081 deletions(-) delete mode 100644 project/Dependencies.scala create mode 100644 project/JniBuildPlugin.scala create mode 100644 project/JniGeneratorPlugin.scala delete mode 100644 project/build.properties create mode 100644 src/main/java/generate/Builder.java create mode 100644 src/main/java/generate/ClassScanner.java create mode 100644 src/main/java/generate/UserClassLoader.java create mode 100644 src/main/java/org/bytedeco/javacpp/presets/torch.java create mode 100644 src/main/java/torch_java/api/Functions.java create mode 100644 src/main/java/torch_java/api/Tensor.java create mode 100644 
src/main/java/torch_java/api/nn/Module.java create mode 100644 src/main/java/torch_java/examples/FourierNet.java delete mode 100644 src/main/resources/logback.xml delete mode 100644 src/main/scala/scorch/autograd/Function.scala delete mode 100644 src/main/scala/scorch/autograd/Variable.scala delete mode 100644 src/main/scala/scorch/data/loader/Cifar10DataLoader.scala delete mode 100644 src/main/scala/scorch/data/loader/DataLoader.scala delete mode 100644 src/main/scala/scorch/data/loader/MnistDataLoader.scala delete mode 100644 src/main/scala/scorch/nn/BatchNorm.scala delete mode 100644 src/main/scala/scorch/nn/Dropout.scala delete mode 100644 src/main/scala/scorch/nn/Linear.scala delete mode 100644 src/main/scala/scorch/nn/Module.scala delete mode 100644 src/main/scala/scorch/nn/ParallelModule.scala delete mode 100644 src/main/scala/scorch/nn/cnn/Conv2d.scala delete mode 100644 src/main/scala/scorch/nn/cnn/MaxPool2d.scala delete mode 100644 src/main/scala/scorch/nn/rnn/GruCell.scala delete mode 100644 src/main/scala/scorch/nn/rnn/LstmCell.scala delete mode 100644 src/main/scala/scorch/nn/rnn/RnnBase.scala delete mode 100644 src/main/scala/scorch/nn/rnn/RnnCell.scala delete mode 100644 src/main/scala/scorch/nn/rnn/RnnCellBase.scala delete mode 100644 src/main/scala/scorch/optim/Adam.scala delete mode 100644 src/main/scala/scorch/optim/DCASGDa.scala delete mode 100644 src/main/scala/scorch/optim/Nesterov.scala delete mode 100644 src/main/scala/scorch/optim/Optimizer.scala delete mode 100644 src/main/scala/scorch/optim/SGD.scala delete mode 100644 src/main/scala/scorch/package.scala delete mode 100644 src/main/scala/scorch/sandbox/MnistWrangler.scala delete mode 100644 src/main/scala/scorch/sandbox/ReadmeConvNet.scala delete mode 100644 src/main/scala/scorch/sandbox/cnn/LeNet5.scala delete mode 100644 src/main/scala/scorch/sandbox/rnn/DinosaurIslandCharRnn.scala delete mode 100644 src/main/scala/scorch/sandbox/rnn/LanguageModel.scala delete mode 100644 src/main/scala/scorch/sandbox/rnn/Rnn.scala delete mode 100644 src/main/scala/scorch/sandbox/rnn/TemporalAffine.scala delete mode 100644 src/main/scala/scorch/sandbox/rnn/TemporalSoftmax.scala delete mode 100644 src/main/scala/scorch/sandbox/rnn/WordEmbedding.scala create mode 100644 src/main/scala/torch_scala/Torch.scala create mode 100644 src/main/scala/torch_scala/api/aten/ArrayRef.scala create mode 100644 src/main/scala/torch_scala/api/aten/Device.scala create mode 100644 src/main/scala/torch_scala/api/aten/Indexer.scala create mode 100644 src/main/scala/torch_scala/api/aten/PrimitivePointer.scala create mode 100644 src/main/scala/torch_scala/api/aten/Scalar.scala create mode 100644 src/main/scala/torch_scala/api/aten/Shape.scala create mode 100644 src/main/scala/torch_scala/api/aten/Tensor.scala create mode 100644 src/main/scala/torch_scala/api/aten/TensorOptions.scala create mode 100644 src/main/scala/torch_scala/api/aten/TensorType.scala create mode 100644 src/main/scala/torch_scala/api/aten/TensorVector.scala create mode 100644 src/main/scala/torch_scala/api/aten/functions/Basic.scala create mode 100644 src/main/scala/torch_scala/api/aten/functions/Functions.scala create mode 100644 src/main/scala/torch_scala/api/aten/functions/Math.scala create mode 100644 src/main/scala/torch_scala/api/aten/functions/MathBackward.scala create mode 100644 src/main/scala/torch_scala/api/nn/Module.scala create mode 100644 src/main/scala/torch_scala/api/package.scala create mode 100644 src/main/scala/torch_scala/api/types/DataType.scala create mode 
100644 src/main/scala/torch_scala/api/types/types.scala create mode 100644 src/main/scala/torch_scala/apps/FourierNet.scala create mode 100644 src/main/scala/torch_scala/autograd/Function.scala create mode 100644 src/main/scala/torch_scala/autograd/MathVariable.scala create mode 100644 src/main/scala/torch_scala/autograd/Variable.scala create mode 100644 src/main/scala/torch_scala/examples/FourierNet.scala create mode 100644 src/main/scala/torch_scala/exceptions/TorchExceptions.scala create mode 100644 src/main/scala/torch_scala/native_generator/Generate.scala create mode 100644 src/main/scala/torch_scala/nn/Linear.scala create mode 100644 src/main/scala/torch_scala/nn/Module.scala create mode 100644 src/main/scala/torch_scala/optim/Adam.scala rename src/main/scala/{scorch => torch_scala}/optim/DCASGD.scala (57%) create mode 100644 src/main/scala/torch_scala/optim/DCASGDa.scala create mode 100644 src/main/scala/torch_scala/optim/Nesterov.scala create mode 100644 src/main/scala/torch_scala/optim/Optimizer.scala create mode 100644 src/main/scala/torch_scala/optim/SGD.scala create mode 100644 src/native/CMakeLists.txt create mode 100644 src/native/detect_cuda_compute_capabilities.cpp create mode 100644 src/native/detect_cuda_version.cc create mode 100644 src/native/helper.h create mode 100644 src/native/java_torch_lib.cpp create mode 100644 src/native/jnijavacpp.cpp create mode 100755 src/native/libjava_torch_lib0.so create mode 100644 src/native/models/FourierNet.cpp create mode 100644 src/native/models/FourierNet.h delete mode 100644 src/test/resources/dinos.txt delete mode 100755 src/test/resources/names/Arabic.txt delete mode 100755 src/test/resources/names/Chinese.txt delete mode 100755 src/test/resources/names/Czech.txt delete mode 100755 src/test/resources/names/Dutch.txt delete mode 100755 src/test/resources/names/English.txt delete mode 100755 src/test/resources/names/French.txt delete mode 100755 src/test/resources/names/German.txt delete mode 100755 src/test/resources/names/Greek.txt delete mode 100755 src/test/resources/names/Irish.txt delete mode 100755 src/test/resources/names/Italian.txt delete mode 100755 src/test/resources/names/Japanese.txt delete mode 100755 src/test/resources/names/Korean.txt delete mode 100755 src/test/resources/names/Polish.txt delete mode 100755 src/test/resources/names/Portuguese.txt delete mode 100755 src/test/resources/names/Russian.txt delete mode 100755 src/test/resources/names/Scottish.txt delete mode 100755 src/test/resources/names/Spanish.txt delete mode 100755 src/test/resources/names/Vietnamese.txt delete mode 100644 src/test/resources/sonnets-cleaned.txt delete mode 100644 src/test/scala/scorch/TestUtil.scala delete mode 100644 src/test/scala/scorch/autograd/AutoGradSpec.scala delete mode 100644 src/test/scala/scorch/autograd/FunctionGradientSpec.scala delete mode 100644 src/test/scala/scorch/autograd/FunctionSpec.scala delete mode 100644 src/test/scala/scorch/data/loader/DataLoaderSpec.scala delete mode 100644 src/test/scala/scorch/nn/BatchNormSpec.scala delete mode 100644 src/test/scala/scorch/nn/ModuleSpec.scala delete mode 100644 src/test/scala/scorch/nn/ParallelModuleSpec.scala delete mode 100644 src/test/scala/scorch/nn/cnn/Conv2dSpec.scala delete mode 100644 src/test/scala/scorch/nn/cnn/MaxPool2dSpec.scala delete mode 100644 src/test/scala/scorch/sandbox/rnn/CharRnnSpec.scala delete mode 100644 src/test/scala/scorch/sandbox/rnn/RnnSpec.scala delete mode 100644 src/test/scala/scorch/sandbox/rnn/TemporalAffineSpec.scala delete mode 
100644 src/test/scala/scorch/sandbox/rnn/TemporalSoftmaxSpec.scala delete mode 100644 src/test/scala/scorch/sandbox/rnn/WordEmbeddingSpec.scala create mode 100644 src/test/scala/torch_scala/TestUtil.scala create mode 100644 src/test/scala/torch_scala/aten/TensorSpec.scala create mode 100644 src/test/scala/torch_scala/autograd/AutoGradSpec.scala create mode 100644 src/test/scala/torch_scala/autograd/FunctionGradientSpec.scala create mode 100644 src/test/scala/torch_scala/autograd/FunctionSpec.scala create mode 100644 src/test/scala/torch_scala/nn/ModuleSpec.scala diff --git a/README.md b/README.md index 4b21efe..fcbc3b0 100644 --- a/README.md +++ b/README.md @@ -1,441 +1,15 @@ -"What I cannot create, I do not understand." - Richard Feynman. +This library provides JNI bindings and a Scala API for the native LibTorch library. It uses +JavaCPP as an automatic code generator. -Scorch -====== -Scorch is a deep learning framework in Scala inspired by PyTorch. +**Installation:** -It has [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) built in -and follows an [imperative coding style](https://mxnet.incubator.apache.org/architecture/program_model.html#symbolic-vs-imperative-programs). - -Scorch uses [numsca](https://github.com/botkop/numsca) for creation and processing of Tensors. - -Here's an example of a convolutional neural net, with relu and pooling followed by 2 affine layers: - -```scala -package scorch.sandbox - -import botkop.{numsca => ns} -import scorch._ -import scorch.autograd.Variable -import scorch.nn.cnn._ -import scorch.nn._ -import scorch.optim.SGD - -object ReadmeConvNet extends App { - - // input layer shape - val (numSamples, numChannels, imageSize) = (8, 3, 32) - val inputShape = List(numSamples, numChannels, imageSize, imageSize) - - // output layer - val numClasses = 10 - - // network blueprint for conv -> relu -> pool -> affine -> affine - case class ConvReluPoolAffineNetwork() extends Module { - - // convolutional layer - val conv = Conv2d(numChannels = 3, numFilters = 32, filterSize = 7, weightScale = 1e-3, pad = 1, stride = 1) - // pooling layer - val pool = MaxPool2d(poolSize = 2, stride = 2) - - // calculate number of flat features - val poolOutShape = pool.outputShape(conv.outputShape(inputShape)) - val numFlatFeatures = poolOutShape.tail.product // all dimensions except the batch dimension - - // reshape from 3d pooling output to 2d affine input - def flatten(v: Variable): Variable = v.reshape(-1, numFlatFeatures) - - // first affine layer - val fc1 = Linear(numFlatFeatures, 100) - // second affine layer (output) - val fc2 = Linear(100, numClasses) - - // chain the layers in a forward pass definition - override def forward(x: Variable): Variable = - x ~> conv ~> relu ~> pool ~> flatten ~> fc1 ~> fc2 - } - - // instantiate the network, and parallelize it - val net = ConvReluPoolAffineNetwork().par() - - // stochastic gradient descent optimizer for updating the parameters - val optimizer = SGD(net.parameters, lr = 0.001) - - // random input and target - val input = Variable(ns.randn(inputShape: _*)) - val target = Variable(ns.randint(numClasses, Array(numSamples, 1))) - - // loop (should reach 100% accuracy in 2 steps) - for (j <- 0 to 3) { - - // reset gradients - optimizer.zeroGrad() - - // forward pass - val output = net(input) - - // calculate the loss - val loss = softmaxLoss(output, target) - - // log accuracy - val guessed = ns.argmax(output.data, axis = 1) - val accuracy = ns.sum(target.data == guessed) / numSamples - println(s"$j: loss: 
${loss.data.squeeze()} accuracy: $accuracy") - - // backward pass - loss.backward() - - // update parameters with gradients - optimizer.step() - } -} -``` - -The documentation below is a copy of the Autograd and Neural Networks sections of -[PyTorch blitz](http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html), -adapted for Scorch. - -## Automatic differentiation - -Central to all neural networks in Scorch is the autograd package. -Let’s first briefly visit this, and we will then go to training our first neural network. - -The `autograd` package provides automatic differentiation for all operations on Tensors. -It is a define-by-run framework, which means that your backprop is defined by how your code is run, and that every single iteration can be different. - -### Variable -`autograd.Variable` is the central class of the package. -It wraps a [numsca](https://github.com/botkop/numsca) `Tensor`, and supports nearly all the operations defined on it. -Once you finish your computation you can call `.backward()` and have all the gradients computed automatically. - -You can access the raw tensor through the `.data` attribute, while the gradient w.r.t. this variable is accumulated into `.grad`. - -### Function -There’s one more class which is very important for autograd implementation - a `Function`. - -`Variable` and `Function` are interconnected and build up an acyclic graph, -that encodes a complete history of computation. -Each variable has a `.gradFn` attribute that references the `Function` that has created the `Variable` -(except for Variables created by the user - their `gradFn` is `None`). - -If you want to compute the derivatives, you can call `.backward()` on a `Variable`. -If you do not specify a gradient argument, then Scorch will create one for you on the fly, -of the same shape as the Variable, and filled with all ones. -(This is different from Pytorch) - -```scala -import scorch.autograd.Variable -import botkop.{numsca => ns} -``` -Create a Variable: -```scala -val x = Variable(ns.ones(2,2)) -``` -```text -x: scorch.autograd.Variable = -data: [[1.00, 1.00], - [1.00, 1.00]] -``` -Do an operation on the Variable: -```scala -val y = x + 2 -``` -``` -y: scorch.autograd.Variable = -data: [[3.00, 3.00], - [3.00, 3.00]] -``` -`y` was created as a result of an operation, so it has a `gradFn`. -```scala -println(y.gradFn) -``` -``` -Some(AddConstant(data: [[1.00, 1.00], - [1.00, 1.00]],2.0)) -``` -Do more operations on `y` - -```scala -val z = y * y * 3 -val out = z.mean() -``` -```text -z: scorch.autograd.Variable = -data: [[27.00, 27.00], - [27.00, 27.00]] -out: scorch.autograd.Variable = data: 27.00 -``` -### Gradients -Let’s backprop now, and print gradients d(out)/dx. -```scala -out.backward() -println(x.grad) -``` -```text -data: [[4.50, 4.50], - [4.50, 4.50]] -``` - -## Neural Networks -Neural networks can be constructed using the `scorch.nn` package. - -Now that you had a glimpse of `autograd`, `nn` depends on autograd to define models and differentiate them. -An `nn.Module` contains layers, and a method `forward(input)` that returns the output. 
- -A typical training procedure for a neural network is as follows: - -* Define the neural network that has some learnable parameters (or weights) -* Iterate over a dataset of inputs -* Process input through the network -* Compute the loss (how far is the output from being correct) -* Propagate gradients back into the network’s parameters -* Update the weights of the network, typically using a simple update rule: - - `weight = weight - learningRate * gradient` - -### Define the network -Let’s define this network: -```scala -import scorch.autograd.Variable -import scorch.nn._ -import scorch._ - -val numSamples = 128 -val numClasses = 10 -val nf1 = 40 -val nf2 = 20 - -// Define a simple neural net -case class Net() extends Module { - val fc1 = Linear(nf1, nf2) // an affine operation: y = Wx + b - val fc2 = Linear(nf2, numClasses) // another one - - // glue the layers with a relu non-linearity: fc1 -> relu -> fc2 - override def forward(x: Variable): Variable = - x ~> fc1 ~> relu ~> fc2 -} - -val net = Net() -``` -You just have to define the forward function. -The backward function (where gradients are computed) is automatically defined for you using autograd. - -The learnable parameters of a model are returned by `net.parameters` - -```scala -val params = net.parameters -println(params.length) -println(params.head.shape) -``` -```text -4 -List(20, 40) -``` -The input to the forward method is an `autograd.Variable`, and so is the output. -```scala -import botkop.{numsca => ns} - -val input = Variable(ns.randn(numSamples, nf1)) -val out = net(input) -println(out) -println(out.shape) -``` -```text -data: [[1.60, -0.22, -0.66, 0.86, -0.59, -0.80, -0.40, -1.37, -1.94, 1.23], - [1.15, -3.81, 5.45, 6.81, -3.02, 2.35, 3.75, 1.79, -7.31, 3.60], - [3.12, -0.94, 2.69, ... +Download LibTorch from https://pytorch.org, for example https://download.pytorch.org/libtorch/cu90/libtorch-shared-with-deps-latest.zip. +Set the path to the extracted LibTorch in `build.sbt`. -List(128, 10) -``` -Zero the gradient buffers of all parameters and backprop with random gradients. -```scala -net.zeroGrad() -out.backward(Variable(ns.randn(numSamples, numClasses))) -``` - -Before proceeding further, let’s recap all the classes you’ve seen so far. - -__Recap:__ -* `numsca.Tensor` - A multi-dimensional array. -* `autograd.Variable` - Wraps a Tensor and records the history of operations applied to it. - - Has (almost) the same API as a `Tensor`, with some additions like `backward()`. Also holds the gradient w.r.t. the tensor. - -* `nn.Module` - Neural network module. Convenient way of encapsulating parameters. -* `autograd.Function` - Implements forward and backward definitions of an autograd operation. - - Every `Variable` operation, creates at least a single `Function` node, - that connects to functions that created a `Variable` and encodes its history. - -__At this point, we covered:__ - -* Defining a neural network -* Processing inputs and calling backward - -__Still Left:__ - -* Computing the loss -* Updating the weights of the network - -### Loss function -A loss function takes the (output, target) pair of inputs, -and computes a value that estimates how far away the output is from the target. - -There are several different loss functions under the `scorch` package . -A common loss is: `scorch.softmaxLoss` which computes the softmax loss between the input and the target.
- -For example: -```scala -val target = Variable(ns.randint(numClasses, Array(numSamples, 1))) -val output = net(input) -val loss = softmaxLoss(output, target) -println(loss) -``` -```text -data: 5.61 -``` - -Now, if you follow loss in the backward direction, -using its `.gradFn` attribute, you will see a graph of computations that looks like this: -```text -input -> linear -> relu -> linear - -> SoftmaxLoss - -> loss -``` -So, when we call `loss.backward()`, -the whole graph is differentiated w.r.t. the loss, -and all Variables in the graph will have their `.grad` Variable accumulated with the gradient. - -### Backprop - -To backpropagate the error all we have to do is to call `loss.backward()`. -You need to clear the existing gradients though, else gradients will be accumulated to existing gradients. - -Now we will call `loss.backward()`, and have a look at fc1's bias gradients before and after the backward. - -```scala -net.zeroGrad() -println("fc1.bias.grad before backward") -println(fc1.bias.grad) -loss.backward() -println("fc1.bias.grad after backward") - -``` -```text -fc1.bias.grad before backward -data: [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00] -fc1.bias.grad after backward -data: [0.07, 0.20, 0.21, -0.04, 0.16, 0.09, 0.34, -0.06, 0.17, -0.06, 0.02, -0.01, -0.07, 0.09, 0.12, -0.04, 0.19, 0.28, 0.06, 0.13] -``` - -Now, we have seen how to use loss functions. - -__The only thing left to learn is:__ - -* Updating the weights of the network - -### Update the weights - -The simplest update rule used in practice is the Stochastic Gradient Descent (SGD): - - `weight = weight - learningRate * gradient` - -We can implement this using simple scala code: - -```scala -net.parameters.foreach(p => p.data -= p.grad.data * learningRate) -``` - -However, as you use neural networks, you want to use various different update rules such as -SGD, Nesterov, Adam, etc. -To enable this, we built a small package: scorch.optim that implements these methods. 
Using it is very simple: -```scala -import scorch.optim.SGD - -// create an optimizer for updating the parameters -val optimizer = SGD(net.parameters, lr = 0.01) - -// in the training loop: - -optimizer.zeroGrad() // reset the gradients of the parameters -val output = net(input) // forward input through the network -val loss = softmaxLoss(output, target) // calculate the loss -loss.backward() // back propagate the derivatives -optimizer.step() // update the parameters with the gradients -``` - -## Wrap up -To wrap up, here is a complete example: - -```scala -import botkop.{numsca => ns} -import scorch.autograd.Variable -import scorch.optim.SGD -import scorch._ - -val numSamples = 128 -val numClasses = 10 -val nf1 = 40 -val nf2 = 20 - -// Define a simple neural net -case class Net() extends Module { - val fc1 = Linear(nf1, nf2) // an affine operation: y = Wx + b - val fc2 = Linear(nf2, numClasses) // another one - - // glue the layers with a relu non-linearity: fc1 -> relu -> fc2 - override def forward(x: Variable) = x ~> fc1 ~> relu ~> fc2 -} - -// instantiate -val net = Net() - -// create an optimizer for updating the parameters -val optimizer = SGD(net.parameters, lr = 0.01) - -// random target and input to train on -val target = Variable(ns.randint(numClasses, Array(numSamples, 1))) -val input = Variable(ns.randn(numSamples, nf1)) - -for (j <- 0 to 1000) { - - // reset the gradients of the parameters - optimizer.zeroGrad() - - // forward input through the network - val output = net(input) - - // calculate the loss - val loss = softmaxLoss(output, target) - - // print loss and accuracy - if (j % 100 == 0) { - val guessed = ns.argmax(output.data, axis = 1) - val accuracy = ns.sum(target.data == guessed) / numSamples - println(s"$j: loss: ${loss.data.squeeze()} accuracy: $accuracy") - } - - // back propagate the derivatives - loss.backward() - - // update the parameters with the gradients - optimizer.step() -} -``` - -## Contributors -Thanks to [Jasper](https://github.com/Jasper-M) for helping out with Scala type inference magic far beyond my capabilities. 
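+For reference, the full wiring in this patch's `build.sbt` looks like the sketch below
+(the LibTorch path is the example used in this patch; point it at your own extracted archive):
+
+```scala
+enablePlugins(JniGeneratorPlugin, JniBuildPlugin)
+
+// JavaCPP drives the JNI code generation
+libraryDependencies += "org.bytedeco" % "javacpp" % "1.4.3"
+
+// location of the extracted LibTorch distribution
+JniBuildPlugin.autoImport.torchLibPath in jniBuild := "/home/nazar/libtorch"
+```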
+The setting that must point at your extracted LibTorch is: +```scala +JniBuildPlugin.autoImport.torchLibPath in jniBuild := "" // path to your extracted LibTorch +``` - -## Dependency -Add this to build.sbt: -```scala -libraryDependencies += "be.botkop" %% "scorch" % "0.1.0" -``` +Build the native .so library: -## References -- [Deep Learning with PyTorch: A 60 Minute Blitz](http://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html) -- [Backpropagation, Intuitions](http://cs231n.github.io/optimization-2/) -- [Automatic Differentiation in Machine Learning: a Survey](https://arxiv.org/pdf/1502.05767.pdf) -- [Automatic differentiation](http://www.pvv.ntnu.no/~berland/resources/autodiff-triallecture.pdf) -- [Derivative Calculator with step-by-step Explanations](http://calculus-calculator.com/derivative/) -- [Differentiation rules](https://en.wikipedia.org/wiki/Differentiation_rules) +```bash +sbt jniBuild +``` diff --git a/build.sbt b/build.sbt index 684422b..620d4ad 100644 --- a/build.sbt +++ b/build.sbt @@ -1,55 +1,27 @@ -resolvers += - "Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots" - -import Dependencies._ - -lazy val root = (project in file(".")).settings( - inThisBuild( - List( - organization := "be.botkop", - scalaVersion := "2.12.5", - version := "0.1.2-SNAPSHOT" - )), - name := "scorch", - libraryDependencies += numsca, - libraryDependencies += scalaTest % Test -) - -crossScalaVersions := Seq("2.11.12", "2.12.4") - -publishTo := { - val nexus = "https://oss.sonatype.org/" - if (isSnapshot.value) - Some("snapshots" at nexus + "content/repositories/snapshots") - else - Some("releases" at nexus + "service/local/staging/deploy/maven2") -} - -pomIncludeRepository := { _ => - false -} - -licenses := Seq( - "BSD-style" -> url("http://www.opensource.org/licenses/bsd-license.php")) - -homepage := Some(url("https://github.com/botkop")) - -scmInfo := Some( - ScmInfo( - url("https://github.com/botkop/scorch"), - "scm:git@github.com:botkop/scorch.git" - ) -) - -developers := List( - Developer( - id = "botkop", - name = "Koen Dejonghe", - email = "koen@botkop.be", - url = url("https://github.com/botkop") - ) -) - -publishMavenStyle := true -publishArtifact in Test := false -// skip in publish := true +import sbt._ +import sbt.Keys._ + + +version := "1.0" + +scalaVersion := "2.12.7" + + +// https://mvnrepository.com/artifact/org.bytedeco/javacpp +libraryDependencies += "org.bytedeco" % "javacpp" % "1.4.3" +libraryDependencies += "org.scala-lang" % "scala-reflect" % "2.12.7" + +enablePlugins(JniGeneratorPlugin, JniBuildPlugin) +JniBuildPlugin.autoImport.torchLibPath in jniBuild := "/home/nazar/libtorch" +//sourceDirectory in nativeCompile := sourceDirectory.value / "native" +//target in nativeCompile :=target.value / "native" / nativePlatform.value + + +libraryDependencies += "com.typesafe.scala-logging" %% "scala-logging" % "3.7.2" +libraryDependencies += "ch.qos.logback" % "logback-classic" % "1.2.3" + +lazy val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3" + +libraryDependencies += scalaTest % Test + + diff --git a/project/Dependencies.scala b/project/Dependencies.scala deleted file mode 100644 index 012b7e8..0000000 --- a/project/Dependencies.scala +++ /dev/null @@ -1,6 +0,0 @@ -import sbt._ - -object Dependencies { - lazy val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3" - lazy val numsca = "be.botkop" %% "numsca" % "0.1.5" -} diff --git a/project/JniBuildPlugin.scala b/project/JniBuildPlugin.scala new file mode 100644 index 0000000..3c96ee6 --- /dev/null +++ b/project/JniBuildPlugin.scala @@ -0,0 +1,48 @@ + +import sbt._ +import sbt.Keys._ + +import sys.process._
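+/**
+ * JniBuildPlugin compiles the generated JNI sources into the native shared
+ * library: the `jniBuild` task configures the CMake project under
+ * `src/native` with CMAKE_PREFIX_PATH pointing at `torchLibPath` (the
+ * LibTorch installation) and then runs `make`. It depends on the `jniGen`
+ * generator task defined by JniGeneratorPlugin, and `compile` is in turn
+ * wired to depend on `jniBuild`.
+ */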
+ + +object JniBuildPlugin extends AutoPlugin { + + override val trigger: PluginTrigger = noTrigger + + override val requires: Plugins = plugins.JvmPlugin + + object autoImport extends JniGeneratorKeys { + lazy val jniBuild = taskKey[Unit]("Builds so lib") + } + + import autoImport._ + + override lazy val projectSettings: Seq[Setting[_]] =Seq( + + targetGeneratorDir in jniBuild := sourceDirectory.value / "native" , + + targetLibName in jniBuild := "java_torch_lib", + + jniBuild := { + val directory = (targetGeneratorDir in jniBuild).value + val cmake_prefix = (torchLibPath in jniBuild).value + val log = streams.value.log + + log.info("Build to " + directory.getAbsolutePath) + val command = s"cmake -H$directory -B$directory -DCMAKE_PREFIX_PATH=$cmake_prefix" + log.info(command) + val exitCode = Process(command) ! log + if (exitCode != 0) sys.error(s"An error occurred while running cmake. Exit code: $exitCode.") + val command1 = s"make -C$directory" + log.info(command1) + val exitCode1 = Process(command1) ! log + if (exitCode1 != 0) sys.error(s"An error occurred while running make. Exit code: $exitCode1.") + }, + + jniBuild := jniBuild.dependsOn(jniGen).value, + compile := (compile in Compile).dependsOn(jniBuild).value, + + ) + + +} diff --git a/project/JniGeneratorPlugin.scala b/project/JniGeneratorPlugin.scala new file mode 100644 index 0000000..42b4143 --- /dev/null +++ b/project/JniGeneratorPlugin.scala @@ -0,0 +1,125 @@ + +import java.io.{File, FileInputStream} + +import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import sbt._ +import sbt.Keys._ + +import sys.process._ + +trait JniGeneratorKeys { + + lazy val torchLibPath = settingKey[String]("Path to C++ torch library.") + + lazy val targetGeneratorDir = settingKey[File]("target directory to store generated cpp files.") + + lazy val targetLibName = settingKey[String]("target cpp file name.") + + lazy val builderClass = settingKey[String]("class name that generates cpp file.") + + lazy val jniGen = taskKey[Unit]("Generates cpp files") + + lazy val javahClasses: TaskKey[Set[String]] = taskKey[Set[String]]( + "Finds the fully qualified names of classes containing native declarations.") + +} + + +object JniGeneratorPlugin extends AutoPlugin { + + override val trigger: PluginTrigger = noTrigger + + override val requires: Plugins = plugins.JvmPlugin + + object autoImport extends JniGeneratorKeys + + import autoImport._ + + override lazy val projectSettings: Seq[Setting[_]] =Seq( + javahClasses in jniGen := { + import xsbti.compile._ + val compiled: CompileAnalysis = (compile in Compile).value + val classFiles: Set[File] = compiled.readStamps.getAllProductStamps.asScala.keySet.toSet + val nativeClasses = classFiles flatMap { file => findNativeClasses(file) } + nativeClasses + }, + + targetGeneratorDir in jniGen := sourceDirectory.value / "native" , + + targetLibName in jniGen := "java_torch_lib", + + builderClass in jniGen := "generate.Builder", + + jniGen := { + val directory = (targetGeneratorDir in jniGen).value + val builder = (builderClass in jniGen).value + val libName = (targetLibName in jniGen).value + // The full classpath cannot be used here since it also generates resources. In a project combining JniJavah and + // JniPackage, we would have a chicken-and-egg problem. 
+ val classPath: String = ((dependencyClasspath in Compile).value.map(_.data) ++ { + Seq((classDirectory in Compile).value) + }).mkString(sys.props("path.separator")) + val classes = (javahClasses in jniGen).value + val log = streams.value.log + + if (classes.nonEmpty) { + log.info("Sources will be generated to " + directory.getAbsolutePath) + log.info("Generating header for " + classes.mkString(" ")) + val command = s"java -classpath $classPath $builder -d $directory -o $libName ${classes.mkString(" ")}" // " torch_scala.NativeLibraryConfig" }" + log.info(command) + val exitCode = Process(command) ! log + if (exitCode != 0) sys.error(s"An error occurred while running javah. Exit code: $exitCode.") + } + } + + ) + + private class NativeFinder extends ClassVisitor(Opcodes.ASM5) { + private var fullyQualifiedName: String = "" + + /** Classes found to contain at least one @native definition. */ + private val _nativeClasses = mutable.HashSet.empty[String] + + def nativeClasses: Set[String] = _nativeClasses.toSet + + override def visit( + version: Int, access: Int, name: String, signature: String, superName: String, + interfaces: Array[String]): Unit = { + fullyQualifiedName = name.replaceAll("/", ".") + } + + override def visitMethod( + access: Int, name: String, desc: String, signature: String, exceptions: Array[String]): MethodVisitor = { + val isNative = (access & Opcodes.ACC_NATIVE) != 0 + if (isNative) + _nativeClasses += fullyQualifiedName + // Return null, meaning that we do not want to visit the method further. + null + } + } + + /** Finds classes containing native implementations (i.e., `@native` definitions). + * + * @param javaFile Java file from which classes are being read. + * @return Set containing all the fully qualified names of classes that contain at least one member annotated with + * the `@native` annotation. + */ + def findNativeClasses(javaFile: File): Set[String] = { + var inputStream: FileInputStream = null + try { + inputStream = new FileInputStream(javaFile) + val reader = new ClassReader(inputStream) + val finder = new NativeFinder + reader.accept(finder, 0) + finder.nativeClasses + } finally { + if (inputStream != null) + inputStream.close() + } + } + + +} diff --git a/project/build.properties b/project/build.properties deleted file mode 100644 index c091b86..0000000 --- a/project/build.properties +++ /dev/null @@ -1 +0,0 @@ -sbt.version=0.13.16 diff --git a/project/plugins.sbt b/project/plugins.sbt index 8de15ce..905736e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1 +1,24 @@ -addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "2.0") \ No newline at end of file +/* Copyright 2017-18, Emmanouil Antonios Platanios. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ + +import sbt.Defaults.sbtPluginExtra + +logLevel := Level.Warn + +libraryDependencies ++= Seq( + "ch.qos.logback" % "logback-classic" % "1.2.3", + "org.ow2.asm" % "asm" % "6.2.1") + + diff --git a/src/main/java/generate/Builder.java b/src/main/java/generate/Builder.java new file mode 100644 index 0000000..b371327 --- /dev/null +++ b/src/main/java/generate/Builder.java @@ -0,0 +1,691 @@ +/* + * Copyright (C) 2011-2018 Samuel Audet + * + * Licensed either under the Apache License, Version 2.0, or (at your option) + * under the terms of the GNU General Public License as published by + * the Free Software Foundation (subject to the "Classpath" exception), + * either version 2, or any later version (collectively, the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.gnu.org/licenses/ + * http://www.gnu.org/software/classpath/license.html + * + * or as provided in the LICENSE.txt file that accompanied this code. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package generate; + +import org.bytedeco.javacpp.ClassProperties; +import org.bytedeco.javacpp.Loader; +import org.bytedeco.javacpp.tools.*; + +import java.io.*; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.*; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.*; +import java.util.jar.JarOutputStream; +import java.util.zip.ZipEntry; + +/** + * The Builder is responsible for coordinating efforts between the Parser, the + * Generator, and the native compiler. It contains the main() method, and basically + * takes care of the tasks one would expect from a command line build tool, but + * can also be used programmatically by setting its properties and calling build(). + * + * @author Samuel Audet + */ +public class Builder { + + /** + * Calls {@link Parser#parse(File, String[], Class)} after creating an instance of the Class. + * + * @param classPath an array of paths to try to load header files from + * @param cls The class annotated with {@link org.bytedeco.javacpp.annotation.Properties} + * and implementing {@link InfoMapper} + * @return the target File produced + * @throws IOException on Java target file writing error + * @throws ParserException on C/C++ header file parsing error + */ + File parse(String[] classPath, Class cls) throws IOException, ParserException { + return new Parser(logger, properties, encoding, null).parse(outputDirectory, classPath, cls); + } + + /** + * Tries to find automatically include paths for {@code jni.h} and {@code jni_md.h}, + * as well as the link and library paths for the {@code jvm} library. 
+ * + * @param properties the Properties containing the paths to update + * @param header to request support for exporting callbacks via generated header file + */ + void includeJavaPaths(ClassProperties properties, boolean header) { + if (properties.getProperty("platform", "").startsWith("android")) { + // Android includes its own jni.h file and doesn't have a jvm library + return; + } + String platform = Loader.getPlatform(); + final String jvmlink = properties.getProperty("platform.link.prefix", "") + + "jvm" + properties.getProperty("platform.link.suffix", ""); + final String jvmlib = properties.getProperty("platform.library.prefix", "") + + "jvm" + properties.getProperty("platform.library.suffix", ""); + final String[] jnipath = new String[2]; + final String[] jvmpath = new String[2]; + FilenameFilter filter = new FilenameFilter() { + @Override public boolean accept(File dir, String name) { + if (new File(dir, "jni.h").exists()) { + jnipath[0] = dir.getAbsolutePath(); + } + if (new File(dir, "jni_md.h").exists()) { + jnipath[1] = dir.getAbsolutePath(); + } + if (new File(dir, jvmlink).exists()) { + jvmpath[0] = dir.getAbsolutePath(); + } + if (new File(dir, jvmlib).exists()) { + jvmpath[1] = dir.getAbsolutePath(); + } + return new File(dir, name).isDirectory(); + } + }; + File javaHome; + try { + javaHome = new File(System.getProperty("java.home")).getParentFile().getCanonicalFile(); + } catch (IOException | NullPointerException e) { + logger.warn("Could not include header files from java.home:" + e); + return; + } + ArrayList dirs = new ArrayList(Arrays.asList(javaHome.listFiles(filter))); + while (!dirs.isEmpty()) { + File d = dirs.remove(dirs.size() - 1); + String dpath = d.getPath(); + File[] files = d.listFiles(filter); + if (dpath == null || files == null) { + continue; + } + for (File f : files) { + try { + f = f.getCanonicalFile(); + } catch (IOException e) { } + if (!dpath.startsWith(f.getPath())) { + dirs.add(f); + } + } + } + if (jnipath[0] != null && jnipath[0].equals(jnipath[1])) { + jnipath[1] = null; + } else if (jnipath[0] == null) { + String macpath = "/System/Library/Frameworks/JavaVM.framework/Headers/"; + if (new File(macpath).isDirectory()) { + jnipath[0] = macpath; + } + } + if (jvmpath[0] != null && jvmpath[0].equals(jvmpath[1])) { + jvmpath[1] = null; + } + properties.addAll("platform.includepath", jnipath); + if (platform.equals(properties.getProperty("platform", platform))) { + if (header) { + // We only need libjvm for callbacks exported with the header file + properties.get("platform.link").add(0, "jvm"); + properties.addAll("platform.linkpath", jvmpath); + } + if (platform.startsWith("macosx")) { + properties.addAll("platform.framework", "JavaVM"); + } + } + } + + + + /** + * Generates a C++ source file for classes, and compiles everything in + * one shared library when {@code compile == true}. + * + * @param classes the Class objects as input to Generator + * @param outputName the output name of the shared library + * @return the actual File generated, either the compiled library or its source + * @throws IOException + * @throws InterruptedException + */ + boolean generate(Class[] classes, String outputName, boolean first, boolean last) throws IOException, InterruptedException { + File outputPath = outputDirectory != null ? 
outputDirectory.getCanonicalFile() : null; + ClassProperties p = Loader.loadProperties(classes, properties, true); + String platform = properties.getProperty("platform"); + String extension = properties.getProperty("platform.extension"); + String sourcePrefix = outputPath != null ? outputPath.getPath() + File.separator : ""; + String sourceSuffix = p.getProperty("platform.source.suffix", ".cpp"); + String libraryPath = p.getProperty("platform.library.path", ""); + String libraryPrefix = p.getProperty("platform.library.prefix", "") ; + String librarySuffix = p.getProperty("platform.library.suffix", ""); + String[] sourcePrefixes = {sourcePrefix, sourcePrefix}; + if (outputPath == null) { + URI uri = null; + try { + String resourceName = '/' + classes[classes.length - 1].getName().replace('.', '/') + ".class"; + String resourceURL = classes[classes.length - 1].getResource(resourceName).toString(); + uri = new URI(resourceURL.substring(0, resourceURL.lastIndexOf('/') + 1)); + boolean isFile = "file".equals(uri.getScheme()); + File classPath = new File(classScanner.getClassLoader().getPaths()[0]).getCanonicalFile(); + // If our class is not a file, use first path of the user class loader as base for our output path + File packageDir = isFile ? new File(uri) + : new File(classPath, resourceName.substring(0, resourceName.lastIndexOf('/') + 1)); + // Output to the library path inside of the class path, if provided by the user + uri = new URI(resourceURL.substring(0, resourceURL.length() - resourceName.length() + 1)); + File targetDir = libraryPath.length() > 0 + ? (isFile ? new File(uri) : classPath) + : new File(packageDir, platform + (extension != null ? extension : "")); + outputPath = new File(targetDir, libraryPath); + sourcePrefix = packageDir.getPath() + File.separator; + // make sure jnijavacpp.cpp ends up in the same directory for all classes in different packages + sourcePrefixes = new String[] {classPath.getPath() + File.separator, sourcePrefix}; + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } catch (IllegalArgumentException e) { + throw new RuntimeException("URI: " + uri, e); + } + } + if (!outputPath.exists()) { + outputPath.mkdirs(); + } + Generator generator = new Generator(logger, properties, encoding); + String[] sourceFilenames = {sourcePrefixes[0] + "jnijavacpp" + sourceSuffix, + sourcePrefixes[1] + outputName + sourceSuffix}; + String[] headerFilenames = {null, header ? sourcePrefixes[1] + outputName + ".h" : null}; + String[] loadSuffixes = {"_jnijavacpp", null}; + String[] baseLoadSuffixes = {null, "_jnijavacpp"}; + String classPath = System.getProperty("java.class.path"); + for (String s : classScanner.getClassLoader().getPaths()) { + classPath += File.pathSeparator + s; + } + String[] classPaths = {null, classPath}; + Class[][] classesArray = {null, classes}; + + boolean generated = true; + for (int i = 0; i < sourceFilenames.length; i++) { + if (i == 0 && !first) { + continue; + } + logger.info("Generating " + sourceFilenames[i]); + if (!generator.generate(sourceFilenames[i], headerFilenames[i], + loadSuffixes[i], baseLoadSuffixes[i], classPaths[i], classesArray[i])) { + logger.info("Nothing generated for " + sourceFilenames[i]); + generated = false; + break; + } + } + + return generated; + + } + + + + /** + * Default constructor that simply initializes everything. + */ + public Builder() { + this(Logger.create(Builder.class)); + } + /** + * Constructor that simply initializes everything. 
+ * @param logger where to send messages + */ + public Builder(Logger logger) { + this.logger = logger; + System.setProperty("org.bytedeco.javacpp.loadlibraries", "false"); + properties = Loader.loadProperties(); + classScanner = new ClassScanner(logger, new ArrayList(), + new UserClassLoader(Thread.currentThread().getContextClassLoader())); + compilerOptions = new ArrayList(); + } + + /** Logger where to send debug, info, warning, and error messages. */ + final Logger logger; + /** The name of the character encoding used for input files as well as output files. */ + String encoding = null; + /** The directory where the generated files and compiled shared libraries get written to. + * By default they are placed in the same directory as the {@code .class} file. */ + File outputDirectory = null; + /** The name of the output generated source file or shared library. This enables single- + * file output mode. By default, the top-level enclosing classes get one file each. */ + String outputName = null; + /** The name of the JAR file to create, if not {@code null}. */ + String jarPrefix = null; + /** If true, compiles the generated source file to a shared library and deletes source. */ + boolean compile = true; + /** If true, preserves the generated C++ JNI files after compilation */ + boolean deleteJniFiles = true; + /** If true, also generates C++ header files containing declarations of callback functions. */ + boolean header = false; + /** If true, also copies to the output directory dependent shared libraries (link and preload). */ + boolean copyLibs = false; + /** If true, also copies to the output directory resources listed in properties. */ + boolean copyResources = false; + /** Accumulates the various properties loaded from resources, files, command line options, etc. */ + Properties properties = null; + /** The instance of the {@link ClassScanner} that fills up a {@link Collection} of {@link Class} objects to process. */ + ClassScanner classScanner = null; + /** A system command for {@link ProcessBuilder} to execute for the build, instead of JavaCPP itself. */ + String[] buildCommand = null; + /** User specified working directory to execute build subprocesses under. */ + File workingDirectory = null; + /** User specified environment variables to pass to the native compiler. */ + Map environmentVariables = null; + /** Contains additional command line options from the user for the native compiler. */ + Collection compilerOptions = null; + + /** Splits argument with {@link File#pathSeparator} and appends result to paths of the {@link #classScanner}. */ + public Builder classPaths(String classPaths) { + classPaths(classPaths == null ? null : classPaths.split(File.pathSeparator)); + return this; + } + /** Appends argument to the paths of the {@link #classScanner}. */ + public Builder classPaths(String ... classPaths) { + classScanner.getClassLoader().addPaths(classPaths); + return this; + } + /** Sets the {@link #encoding} field to the argument. */ + public Builder encoding(String encoding) { + this.encoding = encoding; + return this; + } + /** Sets the {@link #outputDirectory} field to the argument. */ + public Builder outputDirectory(String outputDirectory) { + outputDirectory(outputDirectory == null ? null : new File(outputDirectory)); + return this; + } + /** Sets the {@link #outputDirectory} field to the argument. */ + public Builder outputDirectory(File outputDirectory) { + this.outputDirectory = outputDirectory; + return this; + } + /** Sets the {@link #compile} field to the argument. 
*/ + public Builder compile(boolean compile) { + this.compile = compile; + return this; + } + /** Sets the {@link #deleteJniFiles} field to the argument. */ + public Builder deleteJniFiles(boolean deleteJniFiles) { + this.deleteJniFiles = deleteJniFiles; + return this; + } + /** Sets the {@link #header} field to the argument. */ + public Builder header(boolean header) { + this.header = header; + return this; + } + /** Sets the {@link #copyLibs} field to the argument. */ + public Builder copyLibs(boolean copyLibs) { + this.copyLibs = copyLibs; + return this; + } + /** Sets the {@link #copyResources} field to the argument. */ + public Builder copyResources(boolean copyResources) { + this.copyResources = copyResources; + return this; + } + /** Sets the {@link #outputName} field to the argument. */ + public Builder outputName(String outputName) { + this.outputName = outputName; + return this; + } + /** Sets the {@link #jarPrefix} field to the argument. */ + public Builder jarPrefix(String jarPrefix) { + this.jarPrefix = jarPrefix; + return this; + } + /** Sets the {@link #properties} field to the ones loaded from resources for the specified platform. */ + public Builder properties(String platform) { + if (platform != null) { + properties = Loader.loadProperties(platform, null); + } + return this; + } + /** Adds all the properties of the argument to the {@link #properties} field. */ + public Builder properties(Properties properties) { + if (properties != null) { + for (Map.Entry e : properties.entrySet()) { + property((String)e.getKey(), (String)e.getValue()); + } + } + return this; + } + /** Sets the {@link #properties} field to the ones loaded from the specified file. */ + public Builder propertyFile(String filename) throws IOException { + propertyFile(filename == null ? null : new File(filename)); + return this; + } + /** Sets the {@link #properties} field to the ones loaded from the specified file. */ + public Builder propertyFile(File propertyFile) throws IOException { + if (propertyFile == null) { + return this; + } + FileInputStream fis = new FileInputStream(propertyFile); + properties = new Properties(); + try { + properties.load(new InputStreamReader(fis)); + } catch (NoSuchMethodError e) { + properties.load(fis); + } + fis.close(); + return this; + } + /** Sets a property of the {@link #properties} field, in either "key=value" or "key:value" format. */ + public Builder property(String keyValue) { + int equalIndex = keyValue.indexOf('='); + if (equalIndex < 0) { + equalIndex = keyValue.indexOf(':'); + } + property(keyValue.substring(2, equalIndex), + keyValue.substring(equalIndex+1)); + return this; + } + /** Sets a key/value pair property of the {@link #properties} field. */ + public Builder property(String key, String value) { + if (key.length() > 0 && value.length() > 0) { + properties.put(key, value); + } + return this; + } + /** Requests the {@link #classScanner} to add a class or all classes from a package. + * A {@code null} argument indicates the unnamed package. */ + public Builder classesOrPackages(String ... classesOrPackages) throws IOException, ClassNotFoundException, NoClassDefFoundError { + if (classesOrPackages == null) { + classScanner.addPackage(null, true); + } else for (String s : classesOrPackages) { + classScanner.addClassOrPackage(s); + } + return this; + } + /** Sets the {@link #buildCommand} field to the argument. 
*/ + public Builder buildCommand(String[] buildCommand) { + this.buildCommand = buildCommand; + return this; + } + /** Sets the {@link #workingDirectory} field to the argument. */ + public Builder workingDirectory(String workingDirectory) { + workingDirectory(workingDirectory == null ? null : new File(workingDirectory)); + return this; + } + /** Sets the {@link #workingDirectory} field to the argument. */ + public Builder workingDirectory(File workingDirectory) { + this.workingDirectory = workingDirectory; + return this; + } + /** Sets the {@link #environmentVariables} field to the argument. */ + public Builder environmentVariables(Map environmentVariables) { + this.environmentVariables = environmentVariables; + return this; + } + /** Appends arguments to the {@link #compilerOptions} field. */ + public Builder compilerOptions(String ... options) { + if (options != null) { + compilerOptions.addAll(Arrays.asList(options)); + } + return this; + } + + /** + * Starts the build process and returns an array of {@link File} produced. + * + * @return the array of File produced + * @throws IOException + * @throws InterruptedException + * @throws ParserException + */ + public boolean build() throws IOException, InterruptedException, ParserException { + if (buildCommand != null && buildCommand.length > 0) { + List command = Arrays.asList(buildCommand); + String platform = Loader.getPlatform(); + boolean windows = platform.startsWith("windows"); + for (int i = 0; i < command.size(); i++) { + String arg = command.get(i); + if (arg == null) { + arg = ""; + } + if (arg.trim().isEmpty() && windows) { + // seems to be the only way to pass empty arguments on Windows? + arg = "\"\""; + } + command.set(i, arg); + } + + String text = ""; + for (String s : command) { + boolean hasSpaces = s.indexOf(" ") > 0 || s.isEmpty(); + if (hasSpaces) { + text += windows ? "\"" : "'"; + } + text += s; + if (hasSpaces) { + text += windows ? "\"" : "'"; + } + text += " "; + } + logger.info(text); + + ProcessBuilder pb = new ProcessBuilder(command); + if (workingDirectory != null) { + pb.directory(workingDirectory); + } + if (environmentVariables != null) { + pb.environment().putAll(environmentVariables); + } + String paths = properties.getProperty("platform.buildpath", ""); + String links = properties.getProperty("platform.linkresource", ""); + String resources = properties.getProperty("platform.buildresource", ""); + String separator = properties.getProperty("platform.path.separator"); + if (paths.length() > 0 || resources.length() > 0) { + + // Get all native libraries for classes on the class path. + List libs = new ArrayList(); + ClassProperties libProperties = null; + for (Class c : classScanner.getClasses()) { + if (Loader.getEnclosingClass(c) != c) { + continue; + } + libProperties = Loader.loadProperties(c, properties, true); + if (!libProperties.isLoaded()) { + logger.warn("Could not load platform properties for " + c); + continue; + } + libs.addAll(libProperties.get("platform.preload")); + libs.addAll(libProperties.get("platform.link")); + } + if (libProperties == null) { + libProperties = new ClassProperties(properties); + } + + // Extract the required resources. + for (String s : resources.split(separator)) { + for (File f : Loader.cacheResources(s)) { + String path = f.getCanonicalPath(); + if (paths.length() > 0 && !paths.endsWith(separator)) { + paths += separator; + } + paths += path; + + // Also create symbolic links for native libraries found there. 
+ List linkPaths = new ArrayList(); + for (String s2 : links.split(separator)) { + for (File f2 : Loader.cacheResources(s2)) { + String path2 = f2.getCanonicalPath(); + if (path2.startsWith(path) && !path2.equals(path)) { + linkPaths.add(path2); + } + } + } + File[] files = f.listFiles(); + if (files != null) { + for (File file : files) { + Loader.createLibraryLink(file.getAbsolutePath(), libProperties, null, + linkPaths.toArray(new String[linkPaths.size()])); + } + } + } + } + if (paths.length() > 0) { + pb.environment().put("BUILD_PATH", paths); + pb.environment().put("BUILD_PATH_SEPARATOR", separator); + } + } + int exitValue = pb.inheritIO().start().waitFor(); + if (exitValue != 0) { + throw new RuntimeException("Process exited with an error: " + exitValue); + } + return false; + } + + if (classScanner.getClasses().isEmpty()) { + return false; + } + + List outputFiles = new ArrayList(); + Map> map = new LinkedHashMap>(); + for (Class c : classScanner.getClasses()) { + if (Loader.getEnclosingClass(c) != c) { + continue; + } + ClassProperties p = Loader.loadProperties(c, properties, false); + if (!p.isLoaded()) { + logger.warn("Could not load platform properties for " + c); + continue; + } + try { + if (Arrays.asList(c.getInterfaces()).contains(BuildEnabled.class)) { + ((BuildEnabled)c.newInstance()).init(logger, properties, encoding); + } + } catch (ClassCastException | InstantiationException | IllegalAccessException e) { + // fail silently as if the interface wasn't implemented + } + String target = p.getProperty("target"); + if (target != null && !c.getName().equals(target)) { + File f = parse(classScanner.getClassLoader().getPaths(), c); + if (f != null) { + outputFiles.add(f); + } + continue; + } + String libraryName = outputName != null ? outputName : p.getProperty("platform.library", ""); + if (libraryName.length() == 0) { + continue; + } + LinkedHashSet classList = map.get(libraryName); + if (classList == null) { + map.put(libraryName, classList = new LinkedHashSet()); + } + classList.addAll(p.getEffectiveClasses()); + } + int count = 0; + for (String libraryName : map.keySet()) { + LinkedHashSet classSet = map.get(libraryName); + Class[] classArray = classSet.toArray(new Class[classSet.size()]); + boolean result = generate(classArray, libraryName, count == 0, count == map.size() - 1); + } + + + // reset the load flag to let users load compiled libraries + System.setProperty("org.bytedeco.javacpp.loadlibraries", "true"); + return true; + } + + /** + * Simply prints out to the display the command line usage. 
+ */ + public static void printHelp() { + String version = Builder.class.getPackage().getImplementationVersion(); + if (version == null) { + version = "unknown"; + } + System.out.println( + "JavaCPP version " + version + "\n" + + "Copyright (C) 2011-2017 Samuel Audet \n" + + "Project site: https://github.com/bytedeco/javacpp"); + System.out.println(); + System.out.println("Usage: java -jar javacpp.jar [options] [class or package (suffixed with .* or .**)]"); + System.out.println(); + System.out.println("where options include:"); + System.out.println(); + System.out.println(" -classpath Load user classes from path"); + System.out.println(" -encoding Character encoding used for input and output files"); + System.out.println(" -d Output all generated files to directory"); + System.out.println(" -o Output everything in a file named after given name"); + System.out.println(" -nocompile Do not compile or delete the generated source files"); + System.out.println(" -nodelete Do not delete generated C++ JNI files after compilation"); + System.out.println(" -header Generate header file with declarations of callbacks functions"); + System.out.println(" -copylibs Copy to output directory dependent libraries (link and preload)"); + System.out.println(" -copyresources Copy to output directory resources listed in properties"); + System.out.println(" -jarprefix Also create a JAR file named \"-.jar\""); + System.out.println(" -properties Load all properties from resource"); + System.out.println(" -propertyfile Load all properties from file"); + System.out.println(" -D= Set property to value"); + System.out.println(" -Xcompiler
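A minimal sketch of driving this `Builder` programmatically, mirroring the command line that the `jniGen` sbt task above constructs (`java -classpath <cp> generate.Builder -d <dir> -o <name> <classes>`). The class names chosen here are two of the `torch_java` classes added by this patch, and the sketch assumes they are already compiled and on the classpath:

```scala
import generate.Builder

object GenerateJni extends App {
  // Equivalent to: generate.Builder -d src/native -o java_torch_lib <classes>
  new Builder()
    .outputDirectory("src/native")  // -d: directory for the generated .cpp files
    .outputName("java_torch_lib")   // -o: name of the single generated source
    .classesOrPackages("torch_java.api.Tensor", "torch_java.api.Functions")
    .build()                        // writes java_torch_lib.cpp and jnijavacpp.cpp
}
```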