{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"source": [
"# Clone official repo"
],
"metadata": {
"id": "P1rhi9xgJR-x"
},
"id": "P1rhi9xgJR-x"
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/xuebinqin/DIS\n",
"\n",
"%cd ./DIS/IS-Net\n",
"\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install -q gdown"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wlRB0Pq0JIvF",
"outputId": "8b8e5619-4c39-46b6-8e3c-520c4b68cb23"
},
"id": "wlRB0Pq0JIvF",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into 'DIS'...\n",
"remote: Enumerating objects: 151, done.\u001b[K\n",
"remote: Counting objects: 100% (116/116), done.\u001b[K\n",
"remote: Compressing objects: 100% (83/83), done.\u001b[K\n",
"remote: Total 151 (delta 35), reused 108 (delta 30), pack-reused 35\u001b[K\n",
"Receiving objects: 100% (151/151), 43.23 MiB | 34.50 MiB/s, done.\n",
"Resolving deltas: 100% (37/37), done.\n",
"/content/DIS/IS-Net\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: gdown in /usr/local/lib/python3.7/dist-packages (4.4.0)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.7/dist-packages (from gdown) (4.6.3)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from gdown) (1.15.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from gdown) (4.64.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from gdown) (3.7.1)\n",
"Requirement already satisfied: requests[socks] in /usr/local/lib/python3.7/dist-packages (from gdown) (2.23.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2022.6.15)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2.10)\n",
"Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.7.1)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Imports"
],
"metadata": {
"id": "RO0DY6O3Jqe9"
},
"id": "RO0DY6O3Jqe9"
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from PIL import Image\n",
"import torch\n",
"from torch.autograd import Variable\n",
"from torchvision import transforms\n",
"import torch.nn.functional as F\n",
"import gdown\n",
"import os\n",
"\n",
"import requests\n",
"import matplotlib.pyplot as plt\n",
"from io import BytesIO\n",
"\n",
"# project imports\n",
"from data_loader_cache import normalize, im_reader, im_preprocess \n",
"from models import *\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9fFNd2X_Js0e",
"outputId": "80e92881-6893-4227-a7a7-4f39a566061f"
},
"id": "9fFNd2X_Js0e",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
" warnings.warn(warning.format(ret))\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Helpers"
],
"metadata": {
"id": "h1C9zSdkJgtF"
},
"id": "h1C9zSdkJgtF"
},
{
"cell_type": "code",
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"# Download the official IS-Net weights.\n",
"# Check for the weight file itself, not just the folder: an interrupted or failed\n",
"# earlier download can leave ./saved_models present without isnet.pth, and\n",
"# build_model() would then fail when it loads the checkpoint.\n",
"MODEL_DIR = \"./saved_models\"\n",
"MODEL_FILE = os.path.join(MODEL_DIR, \"isnet.pth\")\n",
"MODEL_PATH_URL = \"https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn\"\n",
"if not os.path.isfile(MODEL_FILE):\n",
"    os.makedirs(MODEL_DIR, exist_ok=True)\n",
"    gdown.download(MODEL_PATH_URL, MODEL_FILE, use_cookies=False)\n",
"\n",
"\n",
"class GOSNormalize:\n",
"    \"\"\"Callable transform that normalizes an image channel-wise.\n",
"\n",
"    Thin wrapper around ``data_loader_cache.normalize`` so it can be used\n",
"    inside ``torchvision.transforms.Compose``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):\n",
"        # NOTE(review): the list defaults are shared between instances; this\n",
"        # class only reads them and never mutates them.\n",
"        self.mean, self.std = mean, std\n",
"\n",
"    def __call__(self, image):\n",
"        return normalize(image, self.mean, self.std)\n",
"\n",
"\n",
"# Inference-time normalization: mean 0.5 and std 1.0 for every channel.\n",
"transform = transforms.Compose([\n",
"    GOSNormalize(mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]),\n",
"])\n",
"\n",
"def load_image(im_path, hypar, timeout=30):\n",
"    \"\"\"Load an image from a local path or an http(s) URL and build a model batch.\n",
"\n",
"    Args:\n",
"        im_path: local file path, or an ``http://`` / ``https://`` URL string.\n",
"        hypar: parameter dict; ``hypar[\"cache_size\"]`` is passed to ``im_preprocess``.\n",
"        timeout: seconds to wait for the HTTP response when ``im_path`` is a URL.\n",
"\n",
"    Returns:\n",
"        ``(image_batch, shape_batch)``: the scaled and normalized image, and the\n",
"        shape reported by ``im_preprocess`` (used by ``predict`` to resize the\n",
"        mask back), each with a leading batch dimension of 1.\n",
"\n",
"    Raises:\n",
"        requests.HTTPError: if downloading the URL returns an error status.\n",
"    \"\"\"\n",
"    # Match the URL scheme, not just the prefix, so a local file such as\n",
"    # \"http_example.jpg\" is not mistaken for a URL; non-str paths pass through.\n",
"    if isinstance(im_path, str) and im_path.startswith((\"http://\", \"https://\")):\n",
"        response = requests.get(im_path, timeout=timeout)\n",
"        # Fail loudly instead of handing an HTML error page to the image decoder.\n",
"        response.raise_for_status()\n",
"        im_path = BytesIO(response.content)\n",
"\n",
"    im = im_reader(im_path)\n",
"    im, im_shp = im_preprocess(im, hypar[\"cache_size\"])\n",
"    im = torch.divide(im, 255.0)\n",
"    shape = torch.from_numpy(np.array(im_shp))\n",
"    return transform(im).unsqueeze(0), shape.unsqueeze(0)  # make a batch of image, shape\n",
"\n",
"\n",
"def build_model(hypar, device):\n",
"    \"\"\"Prepare ``hypar[\"model\"]`` for inference and return it.\n",
"\n",
"    Optionally casts the network to half precision (casting BatchNorm2d layers\n",
"    back to float), moves it to ``device``, restores weights from\n",
"    ``<model_path>/<restore_model>`` when ``restore_model`` is non-empty, and\n",
"    switches it to eval mode. The module stored in ``hypar[\"model\"]`` is\n",
"    modified in place and returned.\n",
"    \"\"\"\n",
"    model = hypar[\"model\"]  # e.g. ISNetDIS()\n",
"\n",
"    if hypar[\"model_digit\"] == \"half\":\n",
"        model.half()\n",
"        # BatchNorm layers are cast back to float after the half() cast.\n",
"        batch_norms = (m for m in model.modules() if isinstance(m, nn.BatchNorm2d))\n",
"        for bn in batch_norms:\n",
"            bn.float()\n",
"\n",
"    model.to(device)\n",
"\n",
"    checkpoint_name = hypar[\"restore_model\"]\n",
"    if checkpoint_name != \"\":\n",
"        checkpoint_path = f\"{hypar['model_path']}/{checkpoint_name}\"\n",
"        state_dict = torch.load(checkpoint_path, map_location=device)\n",
"        model.load_state_dict(state_dict)\n",
"        model.to(device)\n",
"\n",
"    model.eval()\n",
"    return model\n",
"\n",
" \n",
"def predict(net, inputs_val, shapes_val, hypar, device):\n",
"    \"\"\"Predict a foreground mask for one preprocessed image.\n",
"\n",
"    Args:\n",
"        net: model returned by ``build_model``.\n",
"        inputs_val: image batch from ``load_image``; only the first item is used.\n",
"        shapes_val: shape batch from ``load_image``; ``shapes_val[0][0]`` and\n",
"            ``shapes_val[0][1]`` are the height and width the mask is resized to.\n",
"        hypar: parameter dict; ``hypar[\"model_digit\"]`` (\"full\" or \"half\")\n",
"            selects the input precision.\n",
"        device: device the input is moved to (e.g. 'cuda' or 'cpu').\n",
"\n",
"    Returns:\n",
"        numpy ``uint8`` mask, min-max rescaled to [0, 255]; all zeros if the\n",
"        prediction is constant.\n",
"    \"\"\"\n",
"    net.eval()\n",
"\n",
"    if hypar[\"model_digit\"] == \"full\":\n",
"        inputs_val = inputs_val.type(torch.FloatTensor)\n",
"    else:\n",
"        inputs_val = inputs_val.type(torch.HalfTensor)\n",
"\n",
"    # Inference only: no_grad() stops autograd from recording the forward pass,\n",
"    # which would otherwise keep every intermediate activation in memory.\n",
"    with torch.no_grad():\n",
"        inputs_val_v = inputs_val.to(device)\n",
"        ds_val = net(inputs_val_v)[0]  # list of 6 results\n",
"\n",
"        # B x 1 x H x W; the first result is the most accurate prediction\n",
"        pred_val = ds_val[0][0, :, :, :]\n",
"\n",
"        # Recover the prediction's spatial size to the original image size.\n",
"        # F.interpolate replaces the deprecated F.upsample (same bilinear default).\n",
"        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))\n",
"        pred_val = torch.squeeze(\n",
"            F.interpolate(torch.unsqueeze(pred_val, 0), target_size,\n",
"                          mode='bilinear', align_corners=False)\n",
"        )\n",
"\n",
"        ma = torch.max(pred_val)\n",
"        mi = torch.min(pred_val)\n",
"        value_range = ma - mi\n",
"        if value_range > 0:\n",
"            pred_val = (pred_val - mi) / value_range  # max = 1\n",
"        else:\n",
"            # Constant prediction: avoid 0/0 (NaN) and return an all-zero mask.\n",
"            pred_val = torch.zeros_like(pred_val)\n",
"\n",
"    if device == 'cuda':\n",
"        torch.cuda.empty_cache()\n",
"    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)  # it is the mask we need"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "BFVvxhZQJkEy",
"outputId": "e8271029-16de-4501-a579-c04bd16aa420"
},
"id": "BFVvxhZQJkEy",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn\n",
"To: /content/DIS/IS-Net/saved_models/isnet.pth\n",
"100%|██████████| 177M/177M [00:03<00:00, 56.6MB/s]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Set Parameters"
],
"metadata": {
"id": "H7OQxVqaOgtk"
},
"id": "H7OQxVqaOgtk"
},
{
"cell_type": "code",
"execution_count": 4,
"id": "189b719a-c9a2-4048-8620-0501fd5653ec",
"metadata": {
"id": "189b719a-c9a2-4048-8620-0501fd5653ec"
},
"outputs": [],
"source": [
"# Parameters for inference, read by build_model(), load_image() and predict().\n",
"hypar = {\n",
"    # folder the trained weights are loaded from\n",
"    \"model_path\": \"./saved_models\",\n",
"    # file name of the weights to load (\"\" means: do not restore weights)\n",
"    \"restore_model\": \"isnet.pth\",\n",
"    # whether intermediate feature supervision is activated\n",
"    # (not read by the helpers in this notebook)\n",
"    \"interm_sup\": False,\n",
"    # floating point precision: \"half\" or \"full\"\n",
"    \"model_digit\": \"full\",\n",
"    \"seed\": 0,\n",
"    # cached input spatial resolution; can be configured to a different size\n",
"    \"cache_size\": [1024, 1024],\n",
"    # model input spatial size; usually the same as cache_size, which means\n",
"    # the images are not resized any further\n",
"    \"input_size\": [1024, 1024],\n",
"    # random crop size taken from the input for data augmentation; usually\n",
"    # smaller than cache_size, e.g. [920, 920]\n",
"    \"crop_size\": [1024, 1024],\n",
"    \"model\": ISNetDIS(),\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "0af5269e-26a6-4370-8863-92b7381ee90f",
"metadata": {
"tags": [],
"id": "0af5269e-26a6-4370-8863-92b7381ee90f"
},
"source": [
"# Build Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b23ea487-1f64-4443-95b4-7998b5345310",
"metadata": {
"id": "b23ea487-1f64-4443-95b4-7998b5345310"
},
"outputs": [],
"source": [
"# Move the network to `device`, load the pretrained weights and switch to eval mode.\n",
"net = build_model(hypar=hypar, device=device)"
]
},
{
"cell_type": "markdown",
"id": "8beb1f62-0345-4c82-a2e3-9a4db55a55a2",
"metadata": {
"id": "8beb1f62-0345-4c82-a2e3-9a4db55a55a2"
},
"source": [
"# Predict Mask"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "036b21e8-556b-43dd-b9fb-1ea085f7f0f1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 753
},
"id": "036b21e8-556b-43dd-b9fb-1ea085f7f0f1",
"outputId": "1ce15be9-9287-498a-82ec-637b4473abb2"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3722: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n",
" warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n",
"/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1960: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
" warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"