In this blog, we will use the DeepLabv3+ architecture to build a person segmentation pipeline entirely from scratch: by the end, we will have a model that can blur or remove the background behind a person in an image.

DeepLabv3+ Architecture:

DeepLabv3 was introduced in the paper “Rethinking Atrous Convolution for Semantic Image Segmentation”. Building on DeepLabv1 and DeepLabv2, the authors rethought and restructured the DeepLab architecture and came up with the more refined DeepLabv3.

https://miro.medium.com/max/720/1*Llh9dQ1ZMBqPMOJSf7WaBQ.webp

DeepLabv3+ was introduced in the paper “Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation”. It combines the Atrous Spatial Pyramid Pooling (ASPP) of DeepLabv3 (a) with an encoder-decoder architecture (b).

Fig: Architecture of DeepLabV3 with backbone network

Atrous(Dilated) Convolution:

fig 1: 3x3 Atrous (dilated) Convolution in action

Dilated convolutions introduce an additional parameter to the convolution layer called the dilation rate ‘r’. The dilation rate controls the spacing between the kernel points; convolution performed this way is also known as the à trous algorithm. By controlling the rate parameter, we can arbitrarily control the receptive field of the convolution layer. The receptive field is defined as the size of the region of the input feature map that produces each output element. This allows the convolution filter to look at larger areas of the input (a larger receptive field) without decreasing the spatial resolution or increasing the kernel size.

Fig. 1.2: Standard vs Dilated Kernel

Atrous convolution is akin to the standard convolution except that the weights of an atrous convolution kernel are spaced r locations apart, i.e., the kernel of dilated convolution layers is sparse.

The convolution and max-pooling layers used in deep convolutional networks have a disadvantage: at each stage, the spatial resolution of the feature map is halved. Up-sampling the resulting feature map back onto the original image then yields sparse feature extraction.

https://miro.medium.com/max/720/1*dxK0C3WBBqk_eF0k8KRScQ.png

The Atrous convolution allows the convolution filter to look at larger areas of input field without decreasing the spatial resolution or increasing the kernel size.
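A standard way to quantify this (not from the original post): a k x k kernel with dilation rate r covers the same region of the input as a dense kernel of effective size

k_{\text{eff}} = k + (k - 1)(r - 1)

so, for example, a 3x3 kernel with r = 2 sees a 5x5 area of the input while still using only nine weights.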

https://miro.medium.com/max/720/1*zKvtCFhcHpMCQhhjanq12w.png

Let x be the input feature map, y the output and w the filter; then the atrous convolution for each location i of the output y is:

https://miro.medium.com/max/608/1*Gm3S_I_8A4QWIgqmNSDjQQ.png
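For reference, the equation in the image above can be written as:

y[i] = \sum_{k} x[i + r \cdot k]\, w[k]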

where r corresponds to the dilation rate. By adjusting r, we can control the filter's field of view.

Atrous Convolution block in PyTorch (this assumes import torch and import torch.nn as nn):

class Atrous_Convolution(nn.Module):
    """
		Compute Atrous/Dilated Convolution.
    """

    def __init__(
            self, input_channels, kernel_size, pad, dilation_rate,
            output_channels=256):
        super(Atrous_Convolution, self).__init__()

        self.conv = nn.Conv2d(in_channels=input_channels,
                              out_channels=output_channels,
                              kernel_size=kernel_size, padding=pad,
                              dilation=dilation_rate, bias=False)

        self.batchnorm = nn.BatchNorm2d(output_channels)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):

        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x
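A quick shape check for the block above (a sketch; it assumes torch is imported along with the class). With a 3x3 kernel and pad equal to the dilation rate, the spatial size of the input is preserved:

block = Atrous_Convolution(
    input_channels=64, output_channels=256,
    kernel_size=3, pad=6, dilation_rate=6)
x = torch.randn(1, 64, 40, 40)
print(block(x).shape)  # torch.Size([1, 256, 40, 40])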

Encoder:

DeepLabv3+ uses the Atrous Spatial Pyramid Pooling (ASPP) module as its encoder. ASPP probes convolutional features at multiple scales by applying atrous convolution with different rates, together with image-level features.

Atrous Spatial Pyramid Pooling (ASPP):

In ASPP, parallel atrous convolutions with different rates are applied to the input feature map and fused together. ASPP enables encoding of multi-scale contextual information, since objects of the same class can appear at different scales in the image.

In the ASPP layer, one 1x1 convolution and three 3x3 convolutions with different rates (6, 12, 18) are applied. An image-pooling branch is also added for global context. All branches have 256 filters with batch normalization. The resulting feature maps from all branches are then concatenated and passed through a 1x1 convolution, which produces the final encoder features.

https://miro.medium.com/max/720/1*_8p_KTPr5N0HSeIKV35G_g.png

Encoder block in PyTorch (this also uses torch and torch.nn.functional as F):

class ASSP(nn.Module):
    """
			Encoder of DeepLabv3+.
    """

    def __init__(self, in_channles, out_channles):
        """Atrous Spatial Pyramid pooling layer
        Args:
            in_channles (int): No of input channel for Atrous_Convolution.
            out_channles (int): No of output channel for Atrous_Convolution.
        """
        super(ASSP, self).__init__()
        self.conv_1x1 = Atrous_Convolution(
            input_channels=in_channles, output_channels=out_channles,
            kernel_size=1, pad=0, dilation_rate=1)

        self.conv_6x6 = Atrous_Convolution(
            input_channels=in_channles, output_channels=out_channles,
            kernel_size=3, pad=6, dilation_rate=6)

        self.conv_12x12 = Atrous_Convolution(
            input_channels=in_channles, output_channels=out_channles,
            kernel_size=3, pad=12, dilation_rate=12)

        self.conv_18x18 = Atrous_Convolution(
            input_channels=in_channles, output_channels=out_channles,
            kernel_size=3, pad=18, dilation_rate=18)

        self.image_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(
                in_channels=in_channles, out_channels=out_channles,
                kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
            nn.BatchNorm2d(out_channles),
            nn.ReLU(inplace=True))

        self.final_conv = Atrous_Convolution(
            input_channels=out_channles * 5, output_channels=out_channles,
            kernel_size=1, pad=0, dilation_rate=1)

    def forward(self, x):
        x_1x1 = self.conv_1x1(x)
        x_6x6 = self.conv_6x6(x)
        x_12x12 = self.conv_12x12(x)
        x_18x18 = self.conv_18x18(x)
        img_pool_opt = self.image_pool(x)
        img_pool_opt = F.interpolate(
            img_pool_opt, size=x_18x18.size()[2:],
            mode='bilinear', align_corners=True)
        # concatenation of all branch features along the channel dimension
        concat = torch.cat(
            (x_1x1, x_6x6, x_12x12, x_18x18, img_pool_opt),
            dim=1)
        x_final_conv = self.final_conv(concat)
        return x_final_conv
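A minimal shape check for the encoder block (a sketch; it assumes torch is imported and F refers to torch.nn.functional, as used inside the class). The five branches are concatenated to 5 x 256 channels and projected back to 256:

aspp = ASSP(in_channles=1024, out_channles=256)
aspp.eval()  # eval mode so the image-pool BatchNorm accepts a single sample
feats = torch.randn(1, 1024, 40, 40)  # e.g. ResNet50 layer3 output for a 640x640 input
print(aspp(feats).shape)  # torch.Size([1, 256, 40, 40])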

Decoder:

The encoder features are bilinearly up-sampled by a factor of 4 and then concatenated with the corresponding low-level features from the backbone. A 1x1 convolution is applied to the low-level features before concatenation to reduce their number of channels, because low-level features usually contain a large number of channels, which may outweigh the importance of the rich encoder features. After concatenation, a 3x3 convolution is applied to refine the features, followed by another simple bilinear up-sampling by a factor of 4.

Wrapping up the architecture:

For the backbone network we will be using ResNet50 from torchvision (this assumes from torchvision import models in addition to the torch imports used above):

class ResNet_50(nn.Module):
    def __init__(self, output_layer=None):
        super(ResNet_50, self).__init__()
        self.pretrained = models.resnet50(pretrained=True)
        self.output_layer = output_layer
        self.layers = list(self.pretrained._modules.keys())
        self.layer_count = 0
        for l in self.layers:
            if l != self.output_layer:
                self.layer_count += 1
            else:
                break
        for i in range(1, len(self.layers)-self.layer_count):
            self.dummy_var = self.pretrained._modules.pop(self.layers[-i])
        self.net = nn.Sequential(self.pretrained._modules)
        self.pretrained = None

    def forward(self, x):
        x = self.net(x)
        return x
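
# Note (assuming torchvision's standard ResNet-50 layout): truncating at
# 'layer1' gives 256-channel features at 1/4 of the input resolution, while
# 'layer3' gives 1024-channel features at 1/16 resolution. These two outputs
# feed the low-level branch and the ASPP module of the Deeplabv3Plus class below.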
class Deeplabv3Plus(nn.Module):
    def __init__(self, num_classes):

        super(Deeplabv3Plus, self).__init__()

        self.backbone = ResNet_50(output_layer='layer3')

        self.low_level_features = ResNet_50(output_layer='layer1')

        self.assp = ASSP(in_channles=1024, out_channles=256)

        self.conv1x1 = Atrous_Convolution(
            input_channels=256, output_channels=48, kernel_size=1,
            dilation_rate=1, pad=0)

        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.classifer = nn.Conv2d(256, num_classes, 1)

    def forward(self, x):

        x_backbone = self.backbone(x)
        x_low_level = self.low_level_features(x)
        x_assp = self.assp(x_backbone)
        x_assp_upsampled = F.interpolate(
            x_assp, scale_factor=(4, 4),
            mode='bilinear', align_corners=True)
        x_conv1x1 = self.conv1x1(x_low_level)
        x_cat = torch.cat([x_conv1x1, x_assp_upsampled], dim=1)
        x_3x3 = self.conv_3x3(x_cat)
        x_3x3_upscaled = F.interpolate(
            x_3x3, scale_factor=(4, 4),
            mode='bilinear', align_corners=True)
        x_out = self.classifer(x_3x3_upscaled)
        return x_out
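As a quick end-to-end check (a sketch; it assumes a 640x640 input, as used later in this post, and downloads the pretrained ResNet50 weights), the network returns one logit map per class at the input resolution:

model = Deeplabv3Plus(num_classes=1)
model.eval()  # eval mode so the BatchNorm layers accept a single image
with torch.no_grad():
    out = model(torch.randn(1, 3, 640, 640))
print(out.shape)  # expected: torch.Size([1, 1, 640, 640])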

Using Deeplabv3+ for Person Segmentation:

For the dataset, we will be using a person segmentation dataset. It consists of images and masks of 640x640 resolution, with augmentations such as channel shuffle, rotation, and horizontal flip already applied.

The dataset can be downloaded from:

Person segmentation

Datasets and Dataloader:

import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader

TRAIN_IMG_DIR = "data/new_data/train/image"
TRAIN_MASK_DIR = "data/new_data/train/mask"
VAL_IMG_DIR = "data/new_data/test/image"
VAL_MASK_DIR = "data/new_data/test/mask"

class PersonSegmentData(Dataset):

    def __init__(self, image_dir, mask_dir, transform=None) -> None:
        super(PersonSegmentData, self).__init__()
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(
            self.mask_dir, self.images[index].replace(".jpg", ".png"))

        image = np.array(Image.open(image_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"),
                        dtype=np.float32)  # "L" -> grayscale
        mask[mask == 255.0] = 1.0  # binarize: person pixels -> 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']
        return image, mask

def get_data_loaders(
        train_dir, train_mask_dir, val_dir, val_maskdir, batch_size,
        train_transform, val_transform, num_workers=4, pin_memory=True):

    train_ds = PersonSegmentData(
        image_dir=train_dir, mask_dir=train_mask_dir,
        transform=train_transform)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = PersonSegmentData(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        transform=val_transform,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

Hyper-parameters, train and validation loaders:

import os 
import torch
import numpy as np
from PIL import Image
import albumentations as A
from torch.utils.data import Dataset , DataLoader
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim

LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8
NUM_EPOCHS = 10
NUM_WORKERS = 4
PIN_MEMORY = True
LOAD_MODEL = False

# train transform 
train_transform = A.Compose(
    [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# validation transforms
val_transforms = A.Compose(
    [
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

train_loader , val_loader = get_data_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform, 
    val_transforms
)

Visualizing the train_loader/val_loader:

import matplotlib.pyplot as plt
import numpy as np

def show_transformed(train_loader):
    batch = next(iter(train_loader))
    images, labels = batch

    for img , mask in zip(images,labels):
        plt.figure(figsize=(11,11))

        plt.subplot(1,2,1)
        plt.imshow(np.transpose(img , (1,2,0)))

        plt.subplot(1,2,2)
        plt.imshow(mask, cmap="gray")

        plt.show()

show_transformed(val_loader)

Output: sample images and their corresponding masks from the loader.

Loss and Metrics:

Dice loss:

The Dice coefficient, or Dice-Sørensen coefficient, is a common metric for pixel segmentation that can also be modified to act as a loss function.

https://i.stack.imgur.com/OsH4y.png

We prefer Dice loss over cross-entropy because most semantic segmentation datasets are unbalanced. Why are they unbalanced? Suppose you have an image of a cat and you want to segment it as cat (foreground) vs. not-cat (background). In most such images, the majority of pixels belong to the background: on average, 70-90% of the pixels may correspond to the background and only 10-30% to the foreground. With cross-entropy loss, the model can predict most pixels as background even when they are not and still achieve a low error. With Dice loss, which is a function of the intersection and union over the foreground pixels, predicting every pixel as background makes the intersection 0 and drives the error to 1 (the maximum, since Dice loss lies between 0 and 1). Dice loss therefore pushes the model to maximize its overlap with the foreground.

For our task we will be using the BCEDice loss. This loss combines Dice loss with the standard binary cross-entropy (BCE) loss that is generally the default for segmentation models. Combining the two allows for some diversity in the loss while benefiting from the stability of BCE.
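For reference, the Dice coefficient shown above and the smoothed Dice loss used in the code below can be written as

\text{Dice}(A, B) = \frac{2\,|A \cap B|}{|A| + |B|},
\qquad
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i g_i + \epsilon}{\sum_i p_i + \sum_i g_i + \epsilon}

where p_i are the predicted foreground probabilities, g_i the ground-truth labels, and \epsilon a smoothing term (smooth=1 in the code).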

class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, inputs, targets, smooth=1):

        BCE = self.bce_loss(inputs, targets)

        inputs = torch.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(
            inputs.sum() + targets.sum() + smooth)
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

IOU:

Here we will be using Intersection over Union (IoU) as a performance metric for each batch of the training dataset. It measures how well the predicted segmentation overlaps the ground-truth mask.

The IoU of a proposed set of object pixels and a set of true object pixels is calculated as:

https://miro.medium.com/max/720/1*ijgBc4dCoyQuzCZhw0j5bw.png
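Written out (a standard formulation of the metric shown above):

\text{IoU}(A, B) = \frac{|A \cap B|}{|A \cup B|} = \frac{TP}{TP + FP + FN}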

class IOU(nn.Module):

    def __init__(self, weight=None, size_average=True):
        super(IOU, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        # comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        # intersection is equivalent to True Positive count
        # union is the mutually inclusive area of all labels & predictions
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        IoU = (intersection + smooth)/(union + smooth)

        return IoU
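A quick sanity check of both on random tensors (a sketch; it assumes torch and torch.nn as nn are imported, as in the training script below):

logits = torch.randn(4, 1, 160, 160)                     # raw model outputs
targets = torch.randint(0, 2, (4, 1, 160, 160)).float()  # binary masks
print(DiceBCELoss()(logits, targets).item())  # combined BCE + Dice loss
print(IOU()(logits, targets).item())          # soft IoU on sigmoid outputs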

The Training loop:

def save_checkpoint(state, filename="resize.pth.tar"):
    """Saves a checkpoint after each epoch."""
    print("=> Saving checkpoint")
    torch.save(state, filename)

model = Deeplabv3Plus(num_classes=1).to(DEVICE)
loss_fn = DiceBCELoss()
iou_fn = IOU()
scaler = torch.cuda.amp.GradScaler()
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)

train_iou = []
train_loss = []

for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")

    iterations = 0
    iter_loss = 0.0
    iter_iou = 0.0

    batch_loop = tqdm(train_loader)
    for batch_idx,(data,targets) in enumerate(batch_loop):

        data = data.to(device = DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions , targets)
            iou = iou_fn(predictions , targets)

            iter_loss += loss.item()
            iter_iou += iou.item()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        iterations += 1 
        batch_loop.set_postfix(diceloss = loss.item(), iou = iou.item())

    train_loss.append(iter_loss / iterations)
    train_iou.append(iter_iou/iterations)
    print(f"Epoch: {epoch+1}/{NUM_EPOCHS}, Training loss: {round(train_loss[-1] , 3)}")

    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict()
    }
    save_checkpoint(checkpoint)

    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2 * (preds * y).sum()) / (
                (preds + y).sum() + 1e-8
            )

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    print(f"Dice score: {dice_score/len(val_loader)}")
    model.train()
Epoch: 1/10
100%|██████████| 2556/2556 [22:30<00:00,  1.89it/s, diceloss=0.108, iou=0.902]
Epoch: 1/10, Training loss: 0.24
=> Saving checkpoint
Got 222509495/232243200 with acc 95.81
Dice score: 0.9310375452041626

Epoch: 2/10
100%|██████████| 2556/2556 [22:33<00:00,  1.89it/s, diceloss=0.0504, iou=0.943]
Epoch: 2/10, Training loss: 0.136
=> Saving checkpoint
Got 225145355/232243200 with acc 96.94
Dice score: 0.9528669714927673
.
.
.
.
Epoch: 10/10
100%|██████████| 2556/2556 [22:33<00:00,  1.89it/s, diceloss=0.0361, iou=0.972]
Epoch: 10/10, Training loss: 0.042
=> Saving checkpoint
Got 226039920/232243200 with acc 97.33
Dice score: 0.9583981037139893

Fig: BCEDice loss and iou over 10 epochs

Entire training process can be found here:

Segmentation_Deeplabv3+

Testing our model:

Let's build a pipeline that blurs or removes the background from images using the generated masks.

from models.deeplabv3plus import Deeplabv3Plus
import os
import torch
import torchvision
import cv2
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from copy import deepcopy
DEVICE = "cuda:0"
import warnings
warnings.filterwarnings("ignore")

def resize_with_aspect_ratio(
    image, width=None, height=None, inter=cv2.INTER_AREA
):
    dim = None
    (h, w) = image.shape[:2]

    if width is None and height is None:
        return image
    if width is None:
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, dim, interpolation=inter)


class ImageDataset(Dataset):

    def __init__(self, images: np.ndarray, transform=None) -> None:
        super(ImageDataset, self).__init__()
        self.transform = transform
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            augemantations = self.transform(image=image)
            image = augemantations['image']
        return image

from typing import Literal

class SegmentBackground():

    def __init__(self, model_pth: str) -> None:
        self.transforms_ = A.Compose(
            [
                A.Normalize(
                    mean=[0.0, 0.0, 0.0],
                    std=[1.0, 1.0, 1.0],
                    max_pixel_value=255.0,
                ),
                A.Resize(640, 640, p=1.0),
                ToTensorV2(),
            ],)

        self.model = Deeplabv3Plus(num_classes=1).to(DEVICE)
        state = torch.load(model_pth, map_location=DEVICE)
        self.model.load_state_dict(state['state_dict'])

    def blur_backgrond(self, image, mask):

        mask = mask[0].cpu().numpy().transpose(1, 2, 0)
        new_mapp = deepcopy(mask)
        new_mapp[mask == 0.0] = 0
        new_mapp[mask == 1.0] = 255

        orig_imginal = np.array(image)
        mapping_resized = cv2.resize(
            new_mapp,
            (orig_imginal.shape[1], orig_imginal.shape[0]),
            interpolation=cv2.INTER_LINEAR)

        mapping_resized = mapping_resized.astype("uint8")
        mapping_resized = cv2.GaussianBlur(mapping_resized, (0,0), sigmaX=3, sigmaY=3, borderType = cv2.BORDER_DEFAULT)
        

        blurred = cv2.GaussianBlur(mapping_resized, (15, 15), sigmaX=0)
        _, thresholded_img = cv2.threshold(
            blurred, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)
        mapping = cv2.cvtColor(thresholded_img, cv2.COLOR_GRAY2RGB)

        blurred_original_image = cv2.GaussianBlur(orig_imginal,
                                                  (151, 151), 0)
        layered_image = np.where(mapping != (0, 0, 0),
                                 orig_imginal,
                                 blurred_original_image)

        return layered_image
        
        
    def remove(self,image:np.ndarray, mask: np.ndarray) -> np.ndarray:

        mask = mask[0].cpu().numpy().transpose(1, 2, 0)
        new_mapp = deepcopy(mask)
        new_mapp[mask == 0.0] = 0
        new_mapp[mask == 1.0] = 255
        
        orig_imginal = np.array(image)
        mapping_resized = cv2.resize(
            new_mapp,
            (orig_imginal.shape[1], orig_imginal.shape[0]),
            interpolation=cv2.INTER_LINEAR)

        mapping_resized = mapping_resized.astype("uint8")
        
        # mapping_resized = cv2.GaussianBlur(mapping_resized, (0,0), sigmaX=5, sigmaY=5, borderType = cv2.BORDER_DEFAULT)

        kernel = np.ones((5, 5), np.uint8)
        mapping_resized = cv2.erode(mapping_resized, kernel, iterations=1)

        print(f"Mapping Resized: {mapping_resized.shape}")
        # Extract the object using the mask
        masked_object = cv2.bitwise_and(image, image, mask=mapping_resized)
        
        # white background where the mask is empty (3 channels to match the image)
        background = np.where(mapping_resized == 0, 255, 0).astype(np.uint8)
        background = cv2.cvtColor(background, cv2.COLOR_GRAY2BGR)

        finalimage = background + masked_object

        return finalimage
    
    
    def segement(self, image: np.ndarray,
                 operation_type: Literal["blur", "remove"]) -> np.ndarray:
        """Runs the model on a single image and applies the requested operation.

        Args:
            image: input image as a BGR numpy array (as read by cv2).
            operation_type: either "blur" or "remove".
        """

        images = ImageDataset([image],
                                transform=self.transforms_)   
                     
        loader = torch.utils.data.DataLoader(
            images, batch_size=1, num_workers=1)

        self.model.eval()

        for img in loader:
            img = img.to(device=DEVICE)
            with torch.no_grad():
                preds = torch.sigmoid(self.model(img))
                mask = (preds > 0.5).float()
        
        print(f"Shape of mask",mask.shape)
        if operation_type == "blur":
            return self.blur_backgrond(image, mask)

                
        elif operation_type =="remove":

            return self.remove(image,mask)
        

        
        else:
            raise ValueError("Invalid operation_type. It must be either 'blur' or 'remove'.")
                
                
import cv2
from PIL import Image
from IPython.display import display
def display_cv2(img):

    # Convert the image from BGR to RGB (OpenCV uses BGR, but PIL uses RGB)
    image_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Convert the OpenCV image to a PIL image
    pil_image = Image.fromarray(image_rgb)

    # Display the image using IPython.display
    display(pil_image)

segmenter = SegmentBackground("resize.pth.tar")
test_pth = "pexels_test_2.jpg"

img = cv2.imread(test_pth)
img = resize_with_aspect_ratio(img,720)
segmented = segmenter.segement(img,operation_type="remove")
Shape of mask torch.Size([1, 1, 640, 640])
Mapping Resized: (728, 720)
display_cv2(segmented)

Output: person image with the background removed.

blurred = segmenter.segement(img,operation_type="blur")
Shape of mask torch.Size([1, 1, 640, 640])
display_cv2(blurred)

Output: person image with the background blurred.

References:

Major Credit for Atrous convolution: