# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
import numpy as np
import urllib.request
import cv2
import network_executor_tflite
import cv_utils
def style_transfer_postprocess(preprocessed_frame: np.ndarray, image_shape: tuple):
    """
    Resizes the output frame of the style transfer network back to the original
    frame size and converts its color channel order back to BGR.

    Args:
        preprocessed_frame: Network output of shape (1, H, W, 3) in RGB order
                            (the `* 255` scaling assumes values in [0, 1] — confirm
                            against the model).
        image_shape: Shape of the original frame before preprocessing
                     (height first, width second).

    Returns:
        The stylized frame resized to the original dimensions, float32 BGR,
        scaled to [0, 255].
    """
    # Drop the leading batch dimension added for inference
    postprocessed_frame = np.squeeze(preprocessed_frame, axis=0)
    # Original frame dimensions (image_shape is height-first)
    frame_height, frame_width = image_shape[0], image_shape[1]
    # cv2.resize takes dsize as (width, height); rescale values to [0, 255]
    postprocessed_frame = cv2.resize(postprocessed_frame, (frame_width, frame_height)).astype("float32") * 255
    # The network works in RGB; OpenCV frames are BGR
    postprocessed_frame = cv2.cvtColor(postprocessed_frame, cv2.COLOR_RGB2BGR)
    return postprocessed_frame
def create_stylized_detection(style_transfer_executor, style_transfer_class, frame: np.ndarray,
                              detections: list, resize_factor, labels: dict):
    """
    Performs style transfer on every detected object of a given class in a frame.

    Args:
        style_transfer_executor: The style transfer executor.
        style_transfer_class: The class name whose detections should be stylized.
        frame: The original captured frame from the video source (modified in place).
        detections: A list of detected objects in the form [class, [box positions], confidence].
        resize_factor: Resizing factor to scale box coordinates to the output frame size.
        labels: Dictionary of labels and colors keyed on the classification index.

    Returns:
        The input frame with the stylized regions pasted back in place.
    """
    # Frame size is invariant across detections; compute it once outside the loop
    frame_height, frame_width = frame.shape[:2]
    for class_idx, box, confidence in detections:
        label = labels[class_idx][0]
        if label.lower() != style_transfer_class.lower():
            continue
        # Scale bounding box coordinates to the output frame size
        x_min, y_min, x_max, y_max = (int(position * resize_factor) for position in box)
        # Clamp the box so it stays within the frame
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
        # Crop only the detected object
        cropped_frame = cv_utils.crop_bounding_box_object(frame, x_min, y_min, x_max, y_max)
        # Run style transfer on the cropped region
        stylized_frame = style_transfer_executor.run_style_transfer(cropped_frame)
        # Paste the stylized crop back onto the original frame.
        # NOTE(review): the +1 offsets are assumed to mirror the crop made by
        # cv_utils.crop_bounding_box_object so the shapes match — confirm there.
        frame[y_min + 1:y_max, x_min + 1:x_max] = stylized_frame
    return frame
class StyleTransfer:
    """Wraps the TFLite style-predict and style-transfer networks.

    The style bottleneck is computed once from the style image at construction
    time and reused for every subsequent frame.
    """

    def __init__(self, style_predict_model_path: str, style_transfer_model_path: str,
                 style_image: np.ndarray, backends: list, delegate_path: str):
        """
        Creates inference executors for the style predict and style transfer
        networks and precomputes the style bottleneck from the style image.

        Args:
            style_predict_model_path: Model which is used to create a style bottleneck.
            style_transfer_model_path: Model which is used to create stylized frames.
            style_image: An image to create the style bottleneck from.
            backends: List of backends to optimize the networks with.
            delegate_path: TFLite delegate file path (.so).
        """
        self.style_predict_executor = network_executor_tflite.TFLiteNetworkExecutor(style_predict_model_path,
                                                                                    backends, delegate_path)
        self.style_transfer_executor = network_executor_tflite.TFLiteNetworkExecutor(style_transfer_model_path,
                                                                                     backends, delegate_path)
        # The bottleneck depends only on the style image, so compute it once up front
        self.style_bottleneck = self.run_style_predict(style_image)

    def get_style_predict_executor_shape(self):
        """
        Get the input shape of the style predict network.

        Returns:
            tuple: The shape of the network input.
        """
        return self.style_predict_executor.get_shape()

    def run_style_predict(self, style_image):
        """
        Creates the bottleneck tensor for a given style image.

        Args:
            style_image: An image to create the style bottleneck from.

        Returns:
            The style bottleneck tensor.
        """
        # The style image has to be preprocessed to the network's input shape (1, 256, 256, 3)
        preprocessed_style_image = cv_utils.preprocess(style_image, self.style_predict_executor.get_data_type(),
                                                       self.style_predict_executor.get_shape(), True,
                                                       keep_aspect_ratio=False)
        # output[0] is the style bottleneck tensor
        return self.style_predict_executor.run([preprocessed_style_image])[0]

    def run_style_transfer(self, content_image):
        """
        Runs inference for a given content image and the precomputed style
        bottleneck to create a stylized image.

        Args:
            content_image: A content image to stylize.

        Returns:
            The stylized frame, resized back to the content image's dimensions.
        """
        # The content image has to be preprocessed to the network's input shape (1, 384, 384, 3).
        # NOTE(review): dtype is hard-coded to np.float32 here, unlike run_style_predict
        # which queries the executor — confirm the transfer model really is float32.
        preprocessed_content_image = cv_utils.preprocess(content_image, np.float32,
                                                         self.style_transfer_executor.get_shape(), True,
                                                         keep_aspect_ratio=False)
        # Transform the content image; output[0] is the stylized image
        stylized_image = self.style_transfer_executor.run([preprocessed_content_image,
                                                           self.style_bottleneck])[0]
        # Resize back to the original content frame size and restore BGR channel order
        return style_transfer_postprocess(stylized_image, content_image.shape)