import sys

import cv2

import torch

import numpy as np

from PIL import Image, ImageDraw

from torchvision.transforms import functional as F

from PyQt5.QtCore import pyqtSignal, pyqtSlot, QThread

from PyQt5.QtGui import QPixmap, QImage

from PyQt5.QtWidgets import QWidget, QApplication, QLabel, QVBoxLayout

import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_model_instance_segmentation(num_classes, *, hidden_layer=256, weights="DEFAULT"):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with new heads for ``num_classes``.

    The network is created by ``maskrcnn_resnet50_fpn``. Its box predictor and
    mask predictor are then replaced with freshly initialised
    ``FastRCNNPredictor`` / ``MaskRCNNPredictor`` heads sized for
    ``num_classes``. Those heads must be trained, or loaded from a checkpoint,
    before the model produces useful output.

    Args:
        num_classes: Number of classes the new heads predict. Per torchvision
            convention this count includes the background class.
        hidden_layer: Hidden channel width of the new mask predictor.
            The default is 256.
        weights: Forwarded to ``maskrcnn_resnet50_fpn``. ``"DEFAULT"`` starts
            from torchvision's default pretrained checkpoint. ``None`` skips
            that checkpoint, which is useful when a complete ``state_dict`` is
            loaded right afterwards anyway.

    Returns:
        The modified model.
    """
    # Load an instance segmentation model (pretrained unless weights=None).
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # Replace the box classification/regression head, keeping its input width.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask head, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes,
    )

    return model

class VideoThread(QThread):
    """Worker thread: grab webcam frames, run the detector, emit annotated frames.

    Each emitted array is a uint8 image in OpenCV's BGR channel order, the same
    layout ``cv2.VideoCapture.read`` produces. Consumers can therefore treat
    it like a raw capture frame.
    """

    # Carries each annotated BGR frame to the GUI thread.
    pixmap_signal = pyqtSignal(np.ndarray)

    def __init__(self, model_path: str, score_threshold: float = 0.0):
        """Build the model and load its trained weights.

        Args:
            model_path: File written with
                ``torch.save(model.state_dict(), path)`` for a 2-class model
                built by ``get_model_instance_segmentation``.
            score_threshold: Detections scoring below this value are not
                drawn. The default 0.0 draws every detection the model returns.
        """
        super().__init__()
        self._is_running = True
        self.score_threshold = score_threshold
        self.model = get_model_instance_segmentation(2)
        # The model is never moved to a GPU in this class. Map any CUDA tensors
        # in the checkpoint onto the CPU so a GPU-trained checkpoint also loads
        # on a machine without CUDA.
        state = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state)
        self.model.eval()

    def run(self):
        """Capture/predict/emit loop; runs in the worker thread until stop()."""
        capture = cv2.VideoCapture(0)
        try:
            if not capture.isOpened():
                # Without this check the loop below would spin at full CPU on
                # reads that can never succeed.
                print("VideoThread: could not open camera 0", file=sys.stderr)
                return
            while self._is_running:
                ret, frame = capture.read()
                if not ret:
                    # Failed read: back off briefly instead of busy-looping.
                    self.msleep(10)
                    continue
                img = self.convert_to_pil(frame)
                prediction = self.predict(img)
                self.pixmap_signal.emit(
                    self.draw_boxes(img, prediction, self.score_threshold)
                )
        finally:
            # Release the camera even if prediction or drawing raises.
            capture.release()

    def convert_to_pil(self, img: np.ndarray) -> Image.Image:
        """Convert an OpenCV BGR frame to an RGB PIL image."""
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    def predict(self, img: Image.Image):
        """Run the model on one RGB PIL image and return its raw output list."""
        tensor = F.pil_to_tensor(img)
        # uint8 values in [0, 255] -> float values in [0, 1].
        tensor = F.convert_image_dtype(tensor, dtype=torch.float)
        with torch.no_grad():
            return self.model([tensor])

    def draw_boxes(self, img: Image.Image, prediction,
                   score_threshold: float = 0.0) -> np.ndarray:
        """Outline the predicted boxes on ``img`` and return it as a BGR array.

        Args:
            img: RGB PIL image; it is drawn on in place.
            prediction: Model output. Only ``prediction[0]['boxes']`` and
                ``prediction[0]['scores']`` are read.
            score_threshold: Boxes whose score is below this are skipped.

        Returns:
            The annotated image in OpenCV BGR order. This matches what
            ``cv2.VideoCapture`` produces, so BGR-expecting consumers show
            correct colours.
        """
        result = prediction[0]
        draw = ImageDraw.Draw(img)
        for box, score in zip(result['boxes'], result['scores']):
            if float(score) < score_threshold:
                continue
            # tolist() gives PIL a plain [x0, y0, x1, y1] sequence.
            draw.rectangle(box.tolist(), outline="#FF0000")
        # PIL works in RGB; convert back so the emitted frame is BGR.
        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    def stop(self):
        """Ask the loop in run() to exit and block until the thread finishes."""
        self._is_running = False
        self.wait()

class VideoWidget(QWidget):
    """Window that shows the annotated frames emitted by a ``VideoThread``."""

    def __init__(self, model_path: str = './segmentation_model.pt'):
        """Build the UI and start the capture thread.

        Args:
            model_path: Weights file handed to ``VideoThread``. The default,
                ``./segmentation_model.pt``, is resolved against the current
                working directory.
        """
        super().__init__()

        self.display_width = 640
        self.display_height = 480

        self.image_label = QLabel(self)
        layout = QVBoxLayout()
        layout.addWidget(self.image_label)
        self.setLayout(layout)

        self.video_thread = VideoThread(model_path)
        self.video_thread.pixmap_signal.connect(self.update_image)
        self.video_thread.start()

    def closeEvent(self, event):
        """Stop the worker thread and wait for it before the window closes."""
        self.video_thread.stop()
        event.accept()

    def np_to_pixmap(self, img: np.ndarray) -> QPixmap:
        """Convert a BGR uint8 frame into a QPixmap of the display size."""
        rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w, _ = rgb_img.shape
        # Use the array's real row stride rather than assuming w * channels.
        bytes_per_line = rgb_img.strides[0]
        # The QImage wraps rgb_img's buffer without copying. It is only used
        # inside this function, while rgb_img is still alive.
        qimage = QImage(rgb_img.data, w, h, bytes_per_line, QImage.Format_RGB888)
        return QPixmap.fromImage(
            qimage.scaled(self.display_width, self.display_height)
        )

    @pyqtSlot(np.ndarray)
    def update_image(self, img: np.ndarray):
        """Slot for ``VideoThread.pixmap_signal``: show the newest frame."""
        self.image_label.setPixmap(self.np_to_pixmap(img))

if __name__ == '__main__':
    # A QApplication must exist before any widget is constructed.
    app = QApplication(sys.argv)

    window = VideoWidget()
    window.show()

    # Run the Qt event loop, then hand its return code to the shell.
    exit_code = app.exec_()
    sys.exit(exit_code)

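# Residue from the web page this code was copied from (not code):
# "1 комментарий" ("1 comment") / "Комментарий недоступен" ("Comment unavailable").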