Skip to content

new branch #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,519 changes: 733 additions & 786 deletions brain_tumor_dataset_preparation.ipynb

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions conversion/conversion.pytotfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import onnx
from collections import OrderedDict
import tensorflow as tf
from torch.autograd import Variable
from onnx_tf.backend import prepare

class MLP(nn.Module):
def __init__(self, input_dims, n_hiddens, n_class):
super(MLP, self).__init__()
assert isinstance(input_dims, int), 'Please provide int for input_dims'
self.input_dims = input_dims
current_dims = input_dims
layers = OrderedDict()

if isinstance(n_hiddens, int):
n_hiddens = [n_hiddens]
else:
n_hiddens = list(n_hiddens)
for i, n_hidden in enumerate(n_hiddens):
layers['fc{}'.format(i+1)] = nn.Linear(current_dims, n_hidden)
layers['relu{}'.format(i+1)] = nn.ReLU()
layers['drop{}'.format(i+1)] = nn.Dropout(0.2)
current_dims = n_hidden
layers['out'] = nn.Linear(current_dims, n_class)

self.model= nn.Sequential(layers)
print(self.model)

def forward(self, input):
input = input.view(input.size(0), -1)
assert input.size(1) == self.input_dims
return self.model.forward(input)

print("%s" % sys.argv[1])
print("%s" % sys.argv[2])


# Load the trained model from file
trained_dict = torch.load(sys.argv[1], map_location={'cuda:0': 'cpu'})

trained_model = MLP(784, [256, 256], 10)
trained_model.load_state_dict(trained_dict)

if not os.path.exists("%s" % sys.argv[2]):
os.makedirs("%s" % sys.argv[2])

# Export the trained model to ONNX
dummy_input = Variable(torch.randn(1, 1, 28, 28)) # one black and white 28 x 28 picture will be the input to the model
torch.onnx.export(trained_model, dummy_input, "%s/mnist.onnx" % sys.argv[2])

# Load the ONNX file
model = onnx.load("%s/mnist.onnx" % sys.argv[2])

# Import the ONNX model to Tensorflow
tf_rep = prepare(model)

# Input nodes to the model
print('inputs:', tf_rep.inputs)

# Output nodes from the model
print('outputs:', tf_rep.outputs)

# All nodes in the model
print('tensor_dict:')
print(tf_rep.tensor_dict)

tf_rep.export_graph("%s/mnist.pb" % sys.argv[2])

converter = tf.lite.TFLiteConverter.from_frozen_graph(
"%s/mnist.pb" % sys.argv[2], tf_rep.inputs, tf_rep.outputs)
tflite_model = converter.convert()
open("%s/mnist.tflite" % sys.argv[2], "wb").write(tflite_model)
81 changes: 81 additions & 0 deletions endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
from io import BytesIO
from torch import argmax, load
from torch import device as DEVICE
from torch.cuda import is_available
from torch.nn import Sequential, Linear, SELU, Dropout, LogSigmoid
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.models import resnet50
from flask import Flask, jsonify, request

app = Flask(__name__)
LABELS = ['None', 'Meningioma', 'Glioma', 'Pituitary']

device = "cuda" if is_available() else "cpu"

resnet_model = resnet50(pretrained=True)

# Freeze model parameters
for param in resnet_model.parameters():
param.requires_grad = False

# Modify the fully connected layer
n_inputs = resnet_model.fc.in_features
resnet_model.fc = Sequential(Linear(n_inputs, 2048),
SELU(),
Dropout(p=0.4),
Linear(2048, 2048),
SELU(),
Dropout(p=0.4),
Linear(2048, 4),
LogSigmoid())

# Enable gradients for the fully connected layer
for param in resnet_model.fc.parameters():
param.requires_grad = True

resnet_model.to(device)

# Load the model weights
model_path = './models/bt_resnet50_model.pt'
resnet_model.load_state_dict(load(model_path, map_location=DEVICE(device)))
resnet_model.eval()

def preprocess_image(image_bytes):
transform = Compose([Resize((512, 512)), ToTensor()])
img = Image.open(BytesIO(image_bytes))
return transform(img).unsqueeze(0)

def get_prediction(image_bytes):
tensor = preprocess_image(image_bytes=image_bytes)
with torch.no_grad():
y_hat = resnet_model(tensor.to(device))
class_id = argmax(y_hat.data, dim=1)
return str(int(class_id)), LABELS[int(class_id)]

@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
if 'file' not in request.files:
return jsonify({'error': 'No file part'})

file = request.files['file']

if file.filename == '':
return jsonify({'error': 'No selected file'})

img_bytes = file.read()
class_id, class_name = get_prediction(img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})

@app.route('/')
def index():
return 'Welcome to the Brain Tumor Classification API!'

@app.route('/favicon.ico')
def favicon():
return '', 204

if __name__ == '__main__':
app.run(debug=True)
Binary file added static/images/1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added static/images/test1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added static/images/test11.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added static/images/test4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed static/images/test6.jpg
Binary file not shown.
Binary file added static/images/test7.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added static/images/test8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions template/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<html>

<head>
<title>neuralBlack</title>
<title></title>
<link href="https://fonts.googleapis.com/css?family=Lobster+Two" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Roboto+Condensed:400,400i,700" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Concert+One" rel="stylesheet">
Expand Down Expand Up @@ -39,7 +39,7 @@
</body>

<footer>
<p style="left: 45.25%;">Made with <img src="{{ url_for('static', filename='images/love.png') }}" height=3% width=3% alt="love"/> by </p><p style="color: white;"><a href="https://github.com/aksh-ai">aksh-ai</a></p>
<p style="left: 45.25%;">Made with <img src="{{ url_for('static', filename='images/love.png') }}" height=3% width=3% alt="love"/> </p>
</footer>

</html>
4 changes: 2 additions & 2 deletions template/pred.html
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<html>

<head>
<title>neuralBlack</title>
<title></title>
<link href="https://fonts.googleapis.com/css?family=Lobster+Two" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Roboto+Condensed:400,400i,700" rel="stylesheet">
<link href="https://fonts.googleapis.com/css?family=Concert+One" rel="stylesheet">
Expand Down Expand Up @@ -35,7 +35,7 @@
</body>

<footer>
<p style="left: 45.25%;">Made with <img src="{{ url_for('static', filename='images/love.png') }}" height=3% width=3% alt="love"/> by </p><p style="color: white;"><a href="https://github.com/aksh-ai">aksh-ai</a></p>
<p style="left: 45.25%;">Made with <img src="{{ url_for('static', filename='images/love.png') }}" height=3% width=3% alt="love"/> </p><p style="color: white;">
</footer>

</html>
11 changes: 9 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from PIL import Image
from torchvision import transforms, models

device_name = "cuda:0:" if torch.cuda.is_available() else "cpu"
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

resnet_model = models.resnet50(pretrained=True)
Expand All @@ -30,7 +30,14 @@

resnet_model.to(device)

resnet_model.load_state_dict(torch.load('models\\bt_resnet50_model.pt'))
#resnet_model.load_state_dict(torch.load('models\\bt_resnet50_model.pt'))
state_dict = torch.load('models\\bt_resnet50_model.pt')

# Load the state dictionary into the model
resnet_model.load_state_dict(state_dict)

# Print the model to verify if the weights are loaded correctly
# print(resnet_model)

resnet_model.eval()

Expand Down
Loading