Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions PyQt/3dView.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
import pyqtgraph.opengl as gl
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QPushButton
import SimpleITK as sitk
import numpy as np

# Create a PyQt application and main window
app = QApplication(sys.argv)
main_window = QMainWindow()
main_window.setWindowTitle("3D NIfTI Viewer")
central_widget = QWidget()
main_window.setCentralWidget(central_widget)
layout = QVBoxLayout()
central_widget.setLayout(layout)

# Create a PyQtGraph OpenGLWidget to display the 3D image
view = gl.GLViewWidget()
layout.addWidget(view)

# Define a function to load and display the NIfTI image
def load_nifti_and_display():
# Replace 'your_image.nii.gz' with the path to your NIfTI file
nifti_file = "Demo 3d Data/BraTS2021_00000_0000.nii.gz"

# Load the NIfTI image using SimpleITK
sitk_image = sitk.ReadImage(nifti_file)

# Convert the SimpleITK image to a NumPy array
data = sitk.GetArrayFromImage(sitk_image)

# Swap axes to match the expected shape by GLVolumeItem
# data = np.swapaxes(data, 0, 2) # Swap the first and third axes

# Normalize the data to [0, 1]
min_val = np.min(data)
max_val = np.max(data)
normalized_data = (data - min_val) / (max_val - min_val)

# Create a volume item and add it to the view
volume = gl.GLVolumeItem(normalized_data, sliceDensity=2, smooth=True)
# volume.setLevels(min_val, max_val) # Set levels for volume rendering
view.addItem(volume)

# Create a button to trigger the loading and display of the NIfTI image
load_button = QPushButton("Load NIfTI Image")
load_button.clicked.connect(load_nifti_and_display)
layout.addWidget(load_button)

# Show the main window
main_window.show()
sys.exit(app.exec_())
160 changes: 160 additions & 0 deletions PyQt/Add Mas to Image with cv2 .ipynb

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions PyQt/AddMask2Image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import cv2
import numpy as np

# Load your main image
main_image = cv2.imread('main_image.jpg')

# Generate the mask using your segmentation model (replace this with your actual code)
# Assuming you have the mask as a NumPy array with the same shape as the main image
# mask = your_segmentation_model(main_image)

# Make sure the mask has the same number of channels as the main image (3 for RGB)
if len(mask.shape) == 2:
mask = cv2.merge([mask] * 3)

# Overlay the mask on the main image
alpha = 0.5 # You can adjust the alpha value for transparency
result = cv2.addWeighted(main_image, 1 - alpha, mask, alpha, 0)

# Display the result
cv2.imshow('Segmentation Result', result)
cv2.waitKey(0)
cv2.destroyAllWindows()

# Save the result if needed
cv2.imwrite('result_image.jpg', result)
148 changes: 148 additions & 0 deletions PyQt/BasePyQT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import numpy as np
from PIL import Image
import nibabel as nib
import SimpleITK as sitk
from PyQt5.QtWidgets import QApplication, QMainWindow, QWidget, QVBoxLayout, QSlider, QPushButton, QFileDialog
from PyQt5.QtGui import QImage, QPixmap, QPainter
from PyQt5.QtCore import Qt
from MainWindow import MainWindowSegment


class SlicerWidget(QWidget):
def __init__(self, data_path, sitk_image):
super().__init__()
self.data_path = data_path
self.nibabel_data = nib.load(self.data_path).get_fdata()
self.sitk_image = sitk.GetArrayViewFromImage(sitk_image)
self.current_view = 'sagittal'
self.slice_index = max(self.nibabel_data.shape)// 2
self.image_path = None
self.slice_data_nibabel = None
self.slice_data = None
self.max_shape_size = max(self.nibabel_data.shape)
self.sitk_image = self.pad_to_specific_shape(self.sitk_image, (
max(self.nibabel_data.shape), max(self.nibabel_data.shape), max(self.nibabel_data.shape)))
self.nibabel_data = self.pad_to_specific_shape(self.nibabel_data, (
max(self.nibabel_data.shape), max(self.nibabel_data.shape), max(self.nibabel_data.shape)))
print(self.sitk_image.shape)
print(self.nibabel_data.shape)

def pad_to_specific_shape(self, input_array, target_shape, pad_value=0):
"""
Pad a NumPy array to a specific shape.

Parameters:
input_array (numpy.ndarray): The input array to be padded.
target_shape (tuple): The desired shape (tuple of integers) of the padded array.
pad_value (float or int, optional): The value used for padding. Default is 0.

Returns:
numpy.ndarray: The padded array with the specified shape.
"""
# Ensure the input array and target shape have the same number of dimensions
if len(input_array.shape) != len(target_shape):
raise ValueError("Input array and target shape must have the same number of dimensions.")

# Calculate the padding required for each dimension
pad_width = [(0, max(0, target_shape[i] - input_array.shape[i])) for i in range(len(target_shape))]

# Pad the input array
padded_array = np.pad(input_array, pad_width, mode='constant', constant_values=pad_value)

return padded_array

def paintEvent(self, event):
print(self.slice_index)
print(self.current_view)
print('---------------------------------------')
painter = QPainter(self)

if self.current_view == 'sagittal':
self.slice_data = self.sitk_image[:, :, self.slice_index]
self.slice_data_nibabel = self.nibabel_data[:, :, self.slice_index]
if self.current_view == 'coronal':
self.slice_data = self.sitk_image[:, self.slice_index, :]
self.slice_data_nibabel = self.nibabel_data[:, self.slice_index, :]
if self.current_view == 'axial':
self.slice_data = self.sitk_image[self.slice_index, :, :]
self.slice_data_nibabel = self.nibabel_data[self.slice_index, :, :]

slice_data = ((self.slice_data - self.slice_data.min()) / (
self.slice_data.max() - self.slice_data.min()) * 255).astype('uint8')
height, width = slice_data.shape
bytes_per_line = width
image = QImage(slice_data.data, width, height, bytes_per_line, QImage.Format_Grayscale8)
pixmap = QPixmap.fromImage(image)
painter.drawPixmap(0, 0, self.width(), self.height(), pixmap)

def set_current_view(self, view):
self.current_view = view
self.slice_index = self.max_shape_size // 2
self.update()

def set_slice_index(self, index):
self.slice_index = index
self.update()

def save_current_view_as_jpg(self):
print(self.slice_index)
print(self.current_view)
print('********************************')
options = QFileDialog.Options()
options |= QFileDialog.ReadOnly
file_path, _ = QFileDialog.getSaveFileName(self, f"Save {self.current_view.capitalize()} View as JPG", "",
"JPEG Image Files (*.jpg);;All Files (*)", options=options)
rescaled = (255.0 / self.slice_data.max() * (
self.slice_data - self.slice_data.min())).astype(np.uint8)
im = Image.fromarray(rescaled)
im.save(file_path)
self.image_path = file_path


class MainWindow(QMainWindow):
def __init__(self, data_path):
super().__init__()
self.data_path = data_path
self.setWindowTitle("3D Slicer")
self.setGeometry(100, 100, 800, 600)
sitk_image = sitk.ReadImage(self.data_path)
self.slicer_widget = SlicerWidget(self.data_path, sitk_image)

self.scrollbar = QSlider(Qt.Horizontal)
self.scrollbar.setMaximum(sitk_image.GetSize()[0] - 1)
self.scrollbar.valueChanged.connect(self.slicer_widget.set_slice_index)

self.save_button = QPushButton("Save as JPG")
self.save_button.clicked.connect(self.slicer_widget.save_current_view_as_jpg)

self.view_buttons = {
'Sagittal': 'sagittal',
'Coronal': 'coronal',
'Axial': 'axial',
}

self.start_button = QPushButton("Start Segment")
self.start_button.clicked.connect(self.close_window)

for button_text, view in self.view_buttons.items():
button = QPushButton(button_text)
button.clicked.connect(lambda _, view=view: self.slicer_widget.set_current_view(view))
self.view_buttons[button_text] = button

layout = QVBoxLayout()
layout.addWidget(self.slicer_widget)
layout.addWidget(self.scrollbar)
layout.addWidget(self.save_button)
for button_text, button in self.view_buttons.items():
layout.addWidget(button)
layout.addWidget(self.start_button)

central_widget = QWidget()
central_widget.setLayout(layout)
self.setCentralWidget(central_widget)

def close_window(self):
self.close()

def get_image_path(self):
return self.slicer_widget.image_path
Binary file added PyQt/Demo 3d Data/BraTS2021_00000_0000.nii.gz
Binary file not shown.
Binary file added PyQt/Demo 3d Data/BraTS2021_00000_0001.nii.gz
Binary file not shown.
61 changes: 61 additions & 0 deletions PyQt/Inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry
from segment_anything.predictor_sammed import SammedPredictor
from argparse import Namespace

class Inference:
def __init__(self,image_path):
self.args = Namespace()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.args.image_size = 256
self.args.encoder_adapter = True
self.args.sam_checkpoint = ("../Pretrain-Models/sam-med2d_b.pth")
self.model = None
self.predictor = None
self.load_model()
self.image = cv2.imread(image_path)
self.set_image()

def load_model(self):
self.model = sam_model_registry["vit_b"](self.args).to(self.device)
self.predictor = SammedPredictor(self.model)

def set_image(self):
self.predictor.set_image(self.image)

def show_mask(self, mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)

def show_points(self, coords, labels, ax, marker_size=100):
pos_points = coords[labels == 1]
neg_points = coords[labels == 0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='.', s=marker_size, edgecolor='white',
linewidth=0.5)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='.', s=marker_size, edgecolor='white',
linewidth=0.5)

def show_box(self, box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def creat_mask(self, points, labels):
masks, scores, logits = self.predictor.predict(
point_coords=points,
point_labels=labels,
multimask_output=True,
)
return masks, scores, logits

if __name__=="__main__":
c = Inference()
Loading