mirror of
https://github.com/xuebinqin/DIS.git
synced 2024-11-26 08:43:17 +01:00
391 lines
2.4 MiB
Plaintext
391 lines
2.4 MiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {
|
||
|
"id": "view-in-github",
|
||
|
"colab_type": "text"
|
||
|
},
|
||
|
"source": [
|
||
|
"<a href=\"https://colab.research.google.com/github/deshwalmahesh/DIS/blob/main/Colab_Demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"# Clone official repo"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "P1rhi9xgJR-x"
|
||
|
},
|
||
|
"id": "P1rhi9xgJR-x"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
|
"import os\n",
|
||
|
"\n",
|
||
|
"# Clone the repo once and cd into IS-Net. The guard keeps a re-run of this cell\n",
|
||
|
"# (when the kernel is already inside IS-Net) from cloning a nested second copy.\n",
|
||
|
"if os.path.basename(os.getcwd()) != \"IS-Net\":\n",
|
||
|
"    if not os.path.isdir(\"DIS\"):\n",
|
||
|
"        !git clone https://github.com/xuebinqin/DIS\n",
|
||
|
"    %cd ./DIS/IS-Net\n",
|
||
|
"\n",
|
||
|
"# %pip (not !pip) installs into the environment of the running kernel\n",
|
||
|
"%pip install -q gdown"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "wlRB0Pq0JIvF",
|
||
|
"outputId": "8b8e5619-4c39-46b6-8e3c-520c4b68cb23"
|
||
|
},
|
||
|
"id": "wlRB0Pq0JIvF",
|
||
|
"execution_count": 1,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stdout",
|
||
|
"text": [
|
||
|
"Cloning into 'DIS'...\n",
|
||
|
"remote: Enumerating objects: 151, done.\u001b[K\n",
|
||
|
"remote: Counting objects: 100% (116/116), done.\u001b[K\n",
|
||
|
"remote: Compressing objects: 100% (83/83), done.\u001b[K\n",
|
||
|
"remote: Total 151 (delta 35), reused 108 (delta 30), pack-reused 35\u001b[K\n",
|
||
|
"Receiving objects: 100% (151/151), 43.23 MiB | 34.50 MiB/s, done.\n",
|
||
|
"Resolving deltas: 100% (37/37), done.\n",
|
||
|
"/content/DIS/IS-Net\n",
|
||
|
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
|
||
|
"Requirement already satisfied: gdown in /usr/local/lib/python3.7/dist-packages (4.4.0)\n",
|
||
|
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.7/dist-packages (from gdown) (4.6.3)\n",
|
||
|
"Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from gdown) (1.15.0)\n",
|
||
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from gdown) (4.64.0)\n",
|
||
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from gdown) (3.7.1)\n",
|
||
|
"Requirement already satisfied: requests[socks] in /usr/local/lib/python3.7/dist-packages (from gdown) (2.23.0)\n",
|
||
|
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (3.0.4)\n",
|
||
|
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.24.3)\n",
|
||
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2022.6.15)\n",
|
||
|
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (2.10)\n",
|
||
|
"Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.7/dist-packages (from requests[socks]->gdown) (1.7.1)\n"
|
||
|
]
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"# Imports"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "RO0DY6O3Jqe9"
|
||
|
},
|
||
|
"id": "RO0DY6O3Jqe9"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
|
"import numpy as np\n",
|
||
|
"from PIL import Image\n",
|
||
|
"import torch\n",
|
||
|
"from torch.autograd import Variable\n",
|
||
|
"from torchvision import transforms\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"import gdown\n",
|
||
|
"import os\n",
|
||
|
"\n",
|
||
|
"import requests\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from io import BytesIO\n",
|
||
|
"\n",
|
||
|
"# project imports\n",
|
||
|
"from data_loader_cache import normalize, im_reader, im_preprocess \n",
|
||
|
"from models import *\n"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "9fFNd2X_Js0e",
|
||
|
"outputId": "80e92881-6893-4227-a7a7-4f39a566061f"
|
||
|
},
|
||
|
"id": "9fFNd2X_Js0e",
|
||
|
"execution_count": 2,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stderr",
|
||
|
"text": [
|
||
|
"/usr/local/lib/python3.7/dist-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.\n",
|
||
|
" warnings.warn(warning.format(ret))\n"
|
||
|
]
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"# Helpers"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "h1C9zSdkJgtF"
|
||
|
},
|
||
|
"id": "h1C9zSdkJgtF"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"source": [
|
||
|
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
||
|
"\n",
|
||
|
"# Download official weights. Check for the checkpoint file itself (not just the folder)\n",
|
||
|
"# so that an existing-but-empty saved_models folder does not skip the download.\n",
|
||
|
"MODEL_PATH_URL = \"https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn\"\n",
|
||
|
"MODEL_FILE = os.path.join(\"./saved_models\", \"isnet.pth\")\n",
|
||
|
"if not os.path.exists(MODEL_FILE):\n",
|
||
|
"    os.makedirs(\"./saved_models\", exist_ok=True)\n",
|
||
|
"    gdown.download(MODEL_PATH_URL, MODEL_FILE, use_cookies=False)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"class GOSNormalize:\n",
|
||
|
"    '''\n",
|
||
|
"    Callable normalization transform (usable inside transforms.Compose):\n",
|
||
|
"    applies data_loader_cache.normalize with the stored per-channel mean/std.\n",
|
||
|
"    '''\n",
|
||
|
"    def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):\n",
|
||
|
"        # per-channel statistics handed to normalize() on every call\n",
|
||
|
"        self.mean, self.std = mean, std\n",
|
||
|
"\n",
|
||
|
"    def __call__(self, image):\n",
|
||
|
"        return normalize(image, self.mean, self.std)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"# Per-channel mean 0.5 / std 1.0 (overriding GOSNormalize's defaults); load_image\n",
|
||
|
"# applies this after dividing pixel values by 255.\n",
|
||
|
"transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])\n",
|
||
|
"\n",
|
||
|
"def load_image(im_path, hypar):\n",
|
||
|
"    '''\n",
|
||
|
"    Load an image from a local path or an http(s) URL and prepare it for predict().\n",
|
||
|
"\n",
|
||
|
"    im_path: file path, or a URL string starting with \"http\" (non-str values are\n",
|
||
|
"             passed to im_reader unchanged)\n",
|
||
|
"    hypar:   parameter dict; only hypar[\"cache_size\"] is read (passed to im_preprocess)\n",
|
||
|
"    Returns (image, shape), each with a leading batch dimension of 1: the normalized\n",
|
||
|
"    image tensor and the original size reported by im_preprocess.\n",
|
||
|
"    Raises requests.HTTPError if a URL cannot be fetched.\n",
|
||
|
"    '''\n",
|
||
|
"    # only strings can be URLs; anything else goes straight to im_reader\n",
|
||
|
"    if isinstance(im_path, str) and im_path.startswith(\"http\"):\n",
|
||
|
"        response = requests.get(im_path, timeout=30)\n",
|
||
|
"        # fail with the HTTP error instead of handing an error page to the image decoder\n",
|
||
|
"        response.raise_for_status()\n",
|
||
|
"        im_path = BytesIO(response.content)\n",
|
||
|
"\n",
|
||
|
"    im = im_reader(im_path)\n",
|
||
|
"    im, im_shp = im_preprocess(im, hypar[\"cache_size\"])\n",
|
||
|
"    im = torch.divide(im, 255.0)  # assumes 0-255 pixel values -> [0, 1]\n",
|
||
|
"    shape = torch.from_numpy(np.array(im_shp))\n",
|
||
|
"    return transform(im).unsqueeze(0), shape.unsqueeze(0) # make a batch of image, shape\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"def build_model(hypar,device):\n",
|
||
|
"    '''\n",
|
||
|
"    Prepare hypar[\"model\"] for inference and return it (the same object, modified in place):\n",
|
||
|
"    optional fp16 conversion, move to device, load weights from\n",
|
||
|
"    hypar[\"model_path\"]/hypar[\"restore_model\"] when restore_model is non-empty, eval mode.\n",
|
||
|
"    '''\n",
|
||
|
"    net = hypar[\"model\"]#GOSNETINC(3,1)\n",
|
||
|
"\n",
|
||
|
"    # convert to half precision\n",
|
||
|
"    if(hypar[\"model_digit\"]==\"half\"):\n",
|
||
|
"        net.half()\n",
|
||
|
"        # keep BatchNorm layers in float32 (presumably for numerical stability).\n",
|
||
|
"        # torch.nn is spelled out because `nn` is never imported explicitly in this notebook.\n",
|
||
|
"        for layer in net.modules():\n",
|
||
|
"            if isinstance(layer, torch.nn.BatchNorm2d):\n",
|
||
|
"                layer.float()\n",
|
||
|
"\n",
|
||
|
"    net.to(device)\n",
|
||
|
"\n",
|
||
|
"    if(hypar[\"restore_model\"]!=\"\"):\n",
|
||
|
"        weights_path = os.path.join(hypar[\"model_path\"], hypar[\"restore_model\"])\n",
|
||
|
"        # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source\n",
|
||
|
"        net.load_state_dict(torch.load(weights_path, map_location=device))\n",
|
||
|
"        net.to(device)\n",
|
||
|
"    net.eval()\n",
|
||
|
"    return net\n",
|
||
|
"\n",
|
||
|
" \n",
|
||
|
"def predict(net, inputs_val, shapes_val, hypar, device):\n",
|
||
|
"    '''\n",
|
||
|
"    Given an Image, predict the mask.\n",
|
||
|
"\n",
|
||
|
"    inputs_val: batched normalized image tensor (as returned by load_image)\n",
|
||
|
"    shapes_val: batched original (H, W) sizes (as returned by load_image)\n",
|
||
|
"    Returns the first image's mask, resized to its original size, as a uint8 numpy\n",
|
||
|
"    array scaled to 0-255 (all zeros if the network output is constant).\n",
|
||
|
"    '''\n",
|
||
|
"    net.eval()\n",
|
||
|
"\n",
|
||
|
"    if(hypar[\"model_digit\"]==\"full\"):\n",
|
||
|
"        inputs_val = inputs_val.type(torch.FloatTensor)\n",
|
||
|
"    else:\n",
|
||
|
"        inputs_val = inputs_val.type(torch.HalfTensor)\n",
|
||
|
"\n",
|
||
|
"    # inference only: no_grad stops autograd from keeping intermediate activations alive\n",
|
||
|
"    with torch.no_grad():\n",
|
||
|
"        inputs_val_v = inputs_val.to(device)\n",
|
||
|
"\n",
|
||
|
"        ds_val = net(inputs_val_v)[0] # list of 6 results\n",
|
||
|
"\n",
|
||
|
"        # first side output, first image of the batch -> 1 x H x W; the first one is the most accurate prediction\n",
|
||
|
"        pred_val = ds_val[0][0,:,:,:]\n",
|
||
|
"\n",
|
||
|
"        ## recover the prediction spatial size to the original image size\n",
|
||
|
"        orig_hw = (int(shapes_val[0][0]), int(shapes_val[0][1]))\n",
|
||
|
"        # interpolate replaces the deprecated F.upsample; align_corners=False is upsample's bilinear default\n",
|
||
|
"        pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val,0), orig_hw, mode='bilinear', align_corners=False))\n",
|
||
|
"\n",
|
||
|
"        # min-max rescale to [0, 1]; a constant map would divide by zero (NaN), so map it to zeros\n",
|
||
|
"        ma = torch.max(pred_val)\n",
|
||
|
"        mi = torch.min(pred_val)\n",
|
||
|
"        if ma > mi:\n",
|
||
|
"            pred_val = (pred_val-mi)/(ma-mi) # max = 1\n",
|
||
|
"        else:\n",
|
||
|
"            pred_val = torch.zeros_like(pred_val)\n",
|
||
|
"\n",
|
||
|
"    if device == 'cuda': torch.cuda.empty_cache()\n",
|
||
|
"    return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/"
|
||
|
},
|
||
|
"id": "BFVvxhZQJkEy",
|
||
|
"outputId": "e8271029-16de-4501-a579-c04bd16aa420"
|
||
|
},
|
||
|
"id": "BFVvxhZQJkEy",
|
||
|
"execution_count": 3,
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stderr",
|
||
|
"text": [
|
||
|
"Downloading...\n",
|
||
|
"From: https://drive.google.com/uc?id=1KyMpRjewZdyYfxHPYcd-ZbanIXtin0Sn\n",
|
||
|
"To: /content/DIS/IS-Net/saved_models/isnet.pth\n",
|
||
|
"100%|██████████| 177M/177M [00:03<00:00, 56.6MB/s]\n"
|
||
|
]
|
||
|
}
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"source": [
|
||
|
"# Set Parameters"
|
||
|
],
|
||
|
"metadata": {
|
||
|
"id": "H7OQxVqaOgtk"
|
||
|
},
|
||
|
"id": "H7OQxVqaOgtk"
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "189b719a-c9a2-4048-8620-0501fd5653ec",
|
||
|
"metadata": {
|
||
|
"id": "189b719a-c9a2-4048-8620-0501fd5653ec"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"hypar = {} # parameters for inference, read by build_model / load_image / predict\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
|
"hypar[\"model_path\"] =\"./saved_models\" ## folder the trained weights are loaded from\n",
|
||
|
"hypar[\"restore_model\"] = \"isnet.pth\" ## file name of the weights to load (empty string skips loading)\n",
|
||
|
"hypar[\"interm_sup\"] = False ## whether intermediate feature supervision is enabled (not read by the helpers in this notebook)\n",
|
||
|
"\n",
|
||
|
"## choose floating point precision --\n",
|
||
|
"hypar[\"model_digit\"] = \"full\" ## \"half\" (fp16) or \"full\" (fp32) precision; read by build_model and predict\n",
|
||
|
"hypar[\"seed\"] = 0 ## not read by build_model/load_image/predict; nothing in this notebook seeds RNGs with it\n",
|
||
|
"\n",
|
||
|
"hypar[\"cache_size\"] = [1024, 1024] ## spatial size load_image passes to im_preprocess; can be configured to a different size\n",
|
||
|
"\n",
|
||
|
"## data augmentation parameters (training-time; not read by build_model/load_image/predict) ---\n",
|
||
|
"hypar[\"input_size\"] = [1024, 1024] ## model input spatial size, usually the same value as hypar[\"cache_size\"], which means we don't further resize the images\n",
|
||
|
"hypar[\"crop_size\"] = [1024, 1024] ## random crop size from the input, usually set smaller than hypar[\"cache_size\"], e.g., [920,920] for data augmentation\n",
|
||
|
"\n",
|
||
|
"# ISNetDIS comes from `from models import *`; build_model loads the weights into this instance\n",
|
||
|
"hypar[\"model\"] = ISNetDIS()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "0af5269e-26a6-4370-8863-92b7381ee90f",
|
||
|
"metadata": {
|
||
|
"tags": [],
|
||
|
"id": "0af5269e-26a6-4370-8863-92b7381ee90f"
|
||
|
},
|
||
|
"source": [
|
||
|
"# Build Model"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "b23ea487-1f64-4443-95b4-7998b5345310",
|
||
|
"metadata": {
|
||
|
"id": "b23ea487-1f64-4443-95b4-7998b5345310"
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Load hypar[\"model_path\"]/hypar[\"restore_model\"] (./saved_models/isnet.pth) into hypar[\"model\"] and switch it to eval mode\n",
|
||
|
"net = build_model(hypar, device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "8beb1f62-0345-4c82-a2e3-9a4db55a55a2",
|
||
|
"metadata": {
|
||
|
"id": "8beb1f62-0345-4c82-a2e3-9a4db55a55a2"
|
||
|
},
|
||
|
"source": [
|
||
|
"# Predict Mask"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "036b21e8-556b-43dd-b9fb-1ea085f7f0f1",
|
||
|
"metadata": {
|
||
|
"colab": {
|
||
|
"base_uri": "https://localhost:8080/",
|
||
|
"height": 753
|
||
|
},
|
||
|
"id": "036b21e8-556b-43dd-b9fb-1ea085f7f0f1",
|
||
|
"outputId": "1ce15be9-9287-498a-82ec-637b4473abb2"
|
||
|
},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"output_type": "stream",
|
||
|
"name": "stderr",
|
||
|
"text": [
|
||
|
"/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3722: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n",
|
||
|
" warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n",
|
||
|
"/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1960: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
|
||
|
" warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"output_type": "display_data",
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<Figure size 2520x1440 with 2 Axes>"
|
||
|
],
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB88AAAOnCAYAAACu9Qd2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdebBt2V3Y9+9aez77THd694393uun14Ok1tRNSyBhhAZAIAphBYORy+AkRWwMqYoxqZiUCaQSXBBXHBJimaGwY2FbECEIYA0ISuoWSELdrW6p59fTm+88nHHPe+WPtc95txs00nrdtH6fqu537xn2uPZw92/9fksZYxBCCCGEEEIIIYQQQgghhBBCCCG+kekXegGEEEIIIYQQQgghhBBCCCGEEEKIF5oEz4UQQgghhBBCCCGEEEIIIYQQQnzDk+C5EEIIIYQQQgghhBBCCCGEEEKIb3gSPBdCCCGEEEIIIYQQQgghhBBCCPENT4LnQgghhBBCCCGEEEIIIYQQQgghvuFJ8FwIIYQQQgghhBBCCCGEEEIIIcQ3PAmeCyGEEAcopX5GKfUbz/dnv4JpGaXUy56PaQkhhBBCCCGEEEJ8vSilPqGU+q9f6OUQQgghvh4keC6EEOIlSyn1o0qpB5VSU6XUulLqvUqp/pf6jjHmF4wxX9EfgF/NZ/865I9SIYQQQgghhBBCfDWUUueVUrlSavk5r9/fdOA/9cIsmRBCCPHiJsFzIYQQL0lKqZ8CfhH4aaAHvAE4CXxMKeV/ke+4128JhRBCCCGEEEIIIb6ungH+7uwXpdRtQOuFWxwhhBDixU+C50IIIV5ylFJd4OeBnzTGfMQYUxhjzgN/BzgF/L3mcz+nlPqAUuq3lFJD4Eeb137rwLT+vlLqglJqRyn1z5ue22878P3fan4+1fTc/hGl1EWl1LZS6n88MJ07lVKfVkrtK6XWlFK/8sWC+F9m3d6slLqslPrvlVKbzbTepZT6bqXUOaXUrlLqZ77S+SqlvkMp9bhSaqCU+tdKqbsOZrkrpf5LpdSjSqk9pdRHlVInv9plFkIIIYQQQgghxAvifcDfP/D7jwD/fvaLUup7mkz0oVLqklLq5w68FzbPS3aaZwr3KKVWnzsDpdQRpdQXlFI//fVcESGEEOJ6keC5EEKIl6JvAULggwdfNMaMgQ8Bbz/w8vcBHwD6wH84+Hml1MuBfw28BziCzWA/9mXm/SbgZuCtwM8qpW5tXq+A/w5YBr65ef/Hv8r1mjmMXb9jwM8Cv47tEHA78K3AP1dKnf5y821Kt30A+GfAEvA4dtvRvP99wM8AfxtYAT4J/KevcZmFEEIIIYQQQghxfX0G6CqlblVKOcAPAb914P0JNrjeB74H+EdKqXc17/0I9jnICewzg38IJAcn3jx7uAv4FWPM//b1XBEhhBDiepHguRBCiJeiZWDbGFP+Fe+tNe/PfNoY8/vGmNoYkzzns/8F8IfGmD8zxuTYQLX5MvP+eWNMYoz5PPB54NUAxpj7jDGfMcaUTRb8rwLf9tWvGgAF8L8aYwrg/c36/LIxZmSMeRh45Cuc73cDDxtjPthsq/8TWD8wn38I/AtjzKPN+78AvEayz4UQQgghhBBCiL8xZtnnbwceBa7M3jDGfMIY82DzTOQL2A7zs2cGBTZo/jJjTNU8XxgemO7LgY8D/5Mx5teux4oIIYQQ14MEz4UQQrwUbQPLX2QM8yPN+zOXvsR0jh583xgzBXa+zLwPBp+nQBtAKXWTUuqPlFLrTYn4X+DZQfyvxo4xpmp+ngX8Nw68n3yF833u+hng8oHpnAR+uSnPtg/sAoovn30vhBBCCCGEEEKIF4f3AT8M/CgHSrYDKKVer5T6uFJqSyk1wHaiXz7wvY8C71dKXVVK/ZJSyjvw9fdgA/Ef+HqvgBBCCHE9SfBcCCHES9GngQxbbnxOKdUG3gH86YGXv1Qm+R
pw/MD3I2yv66/Fe4HHgLPGmC62HLr6Gqf1fM33ueunDv6ODaz/N8aY/oH/ImPMp67DcgshhBBCCCGEEOKvyRhzAXgGW33ug895+z8CfwCcMMb0gH9D88zAGFMYY37eGPNy7BBv7+TZ46f/HDY54T82JeGFEEKIlwQJngshhHjJMcYMgJ8H/i+l1HcppTyl1Cngd7CZ1e/7Cif1AeB7lVLfopTysX8Yfq0B7w4wBMZKqVuAf/Q1Tuf5nO9/Bm5TSr2rydL/x9jx1Gf+DfDPlFKvAFBK9ZRSP3CdllsIIYQQQgghhBDPj/8KeIsxZvKc1zvArjEmVUrdic1QB0Ap9e1KqduawPgQW8a9PvDdAvgBIAb+vVJKYg1CCCFeEuSCJoQQ4iXJGPNL2Czrf4n9I+8vsJnUbzXGZF/hNB4GfhI7rvgaMAY2sVntX61/iv0jdAT8OvDbX8M0vhZfdL7GmG3sH7q/hC1H/3LgXpr1M8b8HvCL2BJtQ+AhbOa+EEIIIYQQQggh/oYwxjxljLn3r3jrx4H/WSk1An4Wm3QwcxibVDDEjpV+F89JRjDG5Niqf6vAb0oAXQghxEuBssObCiGEEOLLacq+72NLoD/zQi/P8635I/cy8B5jzMdf6OURQgghhBBCCCGEEEIIIa4n6QkmhBBCfAlKqe9VSrWUUjE2i/1B4PwLu1TPH6XUdyql+kqpgGvjoX/mBV4sIYQQQgghhBBCCCGEEOK6u+7B82bs2ceVUk8qpf6H6z1/IYQQ4qv0fcDV5r+zwA+Zl1bZlm8GngK2ge8F3mWMSV7YRRJCCCGEuD7kGYUQQgghhBBCiIOua9l2pZQDnAPeji0Lew/wd40xj1y3hRBCCCGEEEIIIcQ3PHlGIYQQQgghhBDiua535vmdwJPGmKeNMTnwfmxGnxBCCCGEEEIIIcT1JM8ohBBCCCGEEEI8y/UOnh8DLh34/XLzmhBCCCGEEEIIIcT1JM8ohBBCCCGEEEI8i/tCL8BzKaV+DPgxgDiOb7/lllte4CX64mpjeObiJcqyoq4NQeDjui6TyQTP83EcTZqmGAOOo1FK8dwy+UopfN+nrmqMMdTG4LkuZVmilEJrTVmWaEeDgaIoqOuKVismLwrKsrQTMoa4HeO6HlmaUpsapTRVVREEAZ7rEgQB7XZMnhcURUEUhvieQ1Ub9vZGZHkOBuq6xnUUYRjgeR5pluN7HmHo4zouxtQ4rkYrhVKQZgWDwQhjauI4JgwDtKMpyoqtnW2myZTZanuuCwpacYtWFDEcjciynKqq8TzbHO26AwYMBoyhriuMMXieT13XVFWFUoq6rtFaN5vAgAIqOzOtHeqqojYGMBhjmn1gv+P7PmDQWjXvgVJ6vu2LvAAgbsfUdU1ZFChtP9MsHhhFbQx1XVPXtV0ODKY2eJ6HUoqqrtDawXUc8iJHoeaf1VqjtKauKpTW6Gb+ZVWCAaXtvFDgaG3XxRhM03Ycx7Hbx5h525otm8J+39Q1ZVk2rykcR+M4LsYYqqpEKW3bXl3jOA51Xc3bnt1GUFUlxhjKssRz3Xm7cl2XoigwBrSjcRzH/l4bTF1jmnW0+xEcx6GsKoqiwHVdFhcX2d/fpyiKZtnVtf1UG8IoYHFxkel0ymQyfdb+dl2XsqzmP6MUqmk/s/YxPz7swTbfVgBRFDVtz+5vUE0bUNS1wXF0s43svpodw2VV4Wi7Deu6elZbPHh429mZ5gdmLWb+uzF2HvO223ykruv5uhw8W8z2t6M1ZVXZNtPsu1k7dhynOR/VzZf0fNlms26322ilKcqCsiioazOfjzH1tcaDXRbPc3EcjdaaoiioKkOr1cL3fTzXJwxDpklCkWX0F/qUVcn21jZhFFCWJUmSzNfPzNpusz6zfT07N2itKZrj3+5n1RyfHFgH1XzHUJbVbFGb40dhgKpp767jgFLoZh6zfVXVNar5ptJ2erPf7TKa5vxzbb4HzZbbHi/Xtu1smx08ppi1bHXtPG
O3h5qvi2mOa621XV6tbXtqvl8bu7xKKVD2OJkfV83y6GZfX1s+e41S2HMh2LastGqOTfudWVtzXffaMdPMY
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
}
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"image_path = \"https://i5.walmartimages.com/asr/43995148-22bf-4836-b6d3-e8f64a73be54.5398297e6f59fc510e0111bc6ff3a02a.jpeg\"\n",
|
||
|
"\n",
|
||
|
"# Fetch the original once for display (load_image downloads the URL again on its own)\n",
|
||
|
"response = requests.get(image_path, timeout=30)\n",
|
||
|
"response.raise_for_status()\n",
|
||
|
"image_bytes = BytesIO(response.content)\n",
|
||
|
"\n",
|
||
|
"image_tensor, orig_size = load_image(image_path, hypar)\n",
|
||
|
"mask = predict(net, image_tensor, orig_size, hypar, device)\n",
|
||
|
"\n",
|
||
|
"fig, ax = plt.subplots(1, 2, figsize=(35, 20))\n",
|
||
|
"\n",
|
||
|
"ax[0].imshow(np.array(Image.open(image_bytes))) # original image\n",
|
||
|
"ax[1].imshow(mask, cmap='gray', vmin=0, vmax=255) # predicted uint8 mask on a fixed 0-255 gray scale\n",
|
||
|
"\n",
|
||
|
"ax[0].set_title(\"Original Image\")\n",
|
||
|
"ax[1].set_title(\"Mask\")\n",
|
||
|
"for a in ax:\n",
|
||
|
"    a.axis(\"off\")\n",
|
||
|
"\n",
|
||
|
"plt.show()"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.11"
|
||
|
},
|
||
|
"colab": {
|
||
|
"name": "DIS Demo.ipynb",
|
||
|
"provenance": [],
|
||
|
"collapsed_sections": [],
|
||
|
"include_colab_link": true
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|