adm18/IMPAX/nni/nb1.ipynb

236 lines
773 KiB
Text
Raw Normal View History

2025-09-16 05:20:19 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import argparse\n",
"import logging\n",
"import nni\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torchvision import datasets, transforms"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import re"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from torchsummary import summary\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from dataset import *\n",
"from models import *"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Configuration for the visual spot-check of the test set.\n",
"TEST_STEP = 5          # images per visualised batch (= rows in the figure below)\n",
"TEST_DIR = '/shares/Public/IMPAX/test'\n",
"NUM_WORKERS = 6\n",
"SEED = 0\n",
"\n",
"# The shuffled sampler draws its randomness from torch's global RNG; seeding it\n",
"# makes the batch shown below reproducible under Restart & Run All.\n",
"torch.manual_seed(SEED)\n",
"\n",
"testset = IMPAXDataset(TEST_DIR)\n",
"\n",
"testloader = torch.utils.data.DataLoader(\n",
"    testset,\n",
"    batch_size=TEST_STEP,\n",
"    shuffle=True,  # random sample of the test set for visual inspection\n",
"    num_workers=NUM_WORKERS)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"\n",
"GeForce RTX 2080 Ti\n",
"Memory Usage:\n",
"Allocated: 0.0 GB\n",
"Cached: 0.0 GB\n"
]
}
],
"source": [
"# Select the GPU if one is available, otherwise fall back to the CPU.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print('Using device:', device)\n",
"print()\n",
"\n",
"# Additional info when using CUDA.\n",
"if device.type == 'cuda':\n",
"    # torch.cuda.memory_cached() was renamed memory_reserved() in PyTorch 1.4 and is\n",
"    # deprecated; fall back to the old name only on releases predating the rename.\n",
"    _memory_reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached\n",
"    print(torch.cuda.get_device_name(0))\n",
"    print('Memory Usage:')\n",
"    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')\n",
"    print('Cached: ', round(_memory_reserved(0)/1024**3,1), 'GB')\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 256, 256] 640\n",
" Conv2d-2 [-1, 64, 256, 256] 36,928\n",
" Conv2d-3 [-1, 64, 256, 256] 36,928\n",
" Conv2d-4 [-1, 64, 256, 256] 36,928\n",
" Conv2d-5 [-1, 64, 256, 256] 36,928\n",
" Conv2d-6 [-1, 64, 256, 256] 36,928\n",
" Conv2d-7 [-1, 2, 256, 256] 1,154\n",
"================================================================\n",
"Total params: 186,434\n",
"Trainable params: 186,434\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.25\n",
"Forward/backward pass size (MB): 193.00\n",
"Params size (MB): 0.71\n",
"Estimated Total Size (MB): 193.96\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# The checkpoint path encodes the architecture as model-<hidden_layer>-<hidden_size>;\n",
"# recover both hyper-parameters from it so Net is rebuilt with matching shapes.\n",
"PATH = '/home/xfr/nni/model-5-64/TwNuKtj7/best_zdoyO.pth'\n",
"\n",
"# Raw string avoids invalid-escape warnings; '+' requires non-empty digit fields.\n",
"m = re.search(r'model-(\\d+)-(\\d+)', PATH)\n",
"if m is None:\n",
"    raise ValueError(f'Cannot infer hidden_layer/hidden_size from checkpoint path: {PATH}')\n",
"\n",
"hidden_layer = int(m[1])\n",
"hidden_size = int(m[2])\n",
"\n",
"model = Net(hidden_layer, hidden_size)\n",
"# map_location remaps stored tensors onto the active device, so a checkpoint\n",
"# saved on GPU also loads on a CPU-only machine.\n",
"model.load_state_dict(torch.load(PATH, map_location=device))\n",
"model.to(device)\n",
"model.eval()\n",
"\n",
"# torchsummary builds its probe input on device='cuda' by default; pass the active device.\n",
"summary(model, (1, 256, 256), device=device.type)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABIkAAAYpCAYAAADSMUyfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9y49l2VU++J37ft8bz4zId1VWlV1FYYsqbBmDEG3pJxAS/CQGrf79hkgwoSeM8J/QEmMG7UGjlhh0j3qEJRo16gGYgQsbSwa3wVXprHzEO+K+3/ecHkR9O76zYt94ZEZk3LTPkkIRce85+732Wuvba60dRFGEhBJKKKGEEkoooYQSSiihhBJKKKGEfrkpddMNSCihhBJKKKGEEkoooYQSSiihhBJK6OYpAYkSSiihhBJKKKGEEkoooYQSSiihhBJKQKKEEkoooYQSSiihhBJKKKGEEkoooYQSkCihhBJKKKGEEkoooYQSSiihhBJKKCEkIFFCCSWUUEIJJZRQQgkllFBCCSWUUEJIQKKEEkoooYQSSiihhBJKKKGEEkoooYRwjSBREAS/FwTBT4Mg+FkQBN++rnoSSiihi1PClwkltJiU8GZCCS0mJbyZUEKLSQlvJpTQ9VEQRdHVFxoEaQD/AeC/AHgG4PsA/lsURf9+5ZUllFBCF6KELxNKaDEp4c2EElpMSngzoYQWkxLeTCih66Xr8iT6OoCfRVH0WRRFYwD/B4D/ek11JZRQQhejhC8TSmgxKeHNhBJaTEp4M6GEFpMS3kwooWuk6wKJ7gB4Kv8/++KzhBJK6OYo4cuEElpMSngzoYQWkxLeTCihxaSENxNK6Bopc1MVB0HwpwD+FADS6fTH5XIZQRB4n02lUu47/h0EQexzAIiiCFEUue/nEZ/TUDv+P5vNTn2v5fEz+73+1uc8/faWZ9+3ZWr/9bev/HmfzWufbat9nt+HYRgbY1998z5X8vU5CALv2Nr/fc+f1w99dzKZYDabYTqdIgxDV85sNnPvpVIppFIp5HI5rK+vo1AoIJPJYDAYYDgcYjQaYTgc4tatW/jJT36CTCbjyvfN60VoMpnsR1G0dqmXromUN1Op1MeFQmEuT+l6Jj9y/JTmrXP9zK4zfZdzZL+zZdm1o2207bBr6ov++oZkLm/avy1/X7SseWT7w/Zpf+yPbaOt+7K8qRSG4Zn/++bjojSZTDCdTjGdTmN7Afdkrq10Oo1cLoeNjQ0Ui0XHm/1+H6PRCIPBANPpFLPZDOPx+Mw26F4273sAC8ublUrlSuSmlD+37peVmzqX+tvHB+fJzXnf2b8vKzdtOb4ybPvOkptWlr3JchOA40vuwUEQIAzDGP8HQeB489atWygUCkin044vh8MhOp2OWyvK15elKIq4T1z+5Wsi5U0AH99kWxJK6GXo448/xr/8y79cSVmLzJvZbPbCcu88etk9jO8qXVR/epX2JvT6ycphnw50FuVyOTx8+BA//elPY59f1t4MgoC274UX0HWBRM8B3JP/737xmaMoir4D4DsAsLS0FP3u7/4uwjB0hgKV+zAMkcvlUCqVUCgUQKU4m82i0Wggm806hYVKx2g0cspLJpNBOp12g6nK7BftcPVMJhP0ej30+32EYeiABBIVpOl06j5jvcCxAloqlZyRTGOHSpW+o4Y0FWC2UxV5bX82m0U2m40Z4+yDvpdOp2Pv8zM7Dkqq8PG5VCqF2WzmfsbjsTPS2X6Wrco4fxQ8UWPCEttK8Ib1sZ3D4dDNVz6fRyaT8YKDCmKxL5YZnz59imaziU6n49ozHA7RbrddW6rVKu7evYuvfOUr+LM/+zPcv38fAPDP//zP+M///E88efIEP/3pT/Hnf/7n+K3f+i08ePAAzWYTYRiiVCphPB7H2qb1W0OB329tbT05NTBXT+fy5RdtdLxZq9Wib37zm25OJpNJbD6KxSIKhQKKxSKKxS
Ky2SxyuZzjzVQq5XhkNps53oyiCNls1q1LzhnnnTyi66HT6WA4HLo9gnPM8izP6nogb7I+BQvJI8rralSTN6Mocs/qGieP5nI5pNPp2Fha41x5hnzKMuwe9cVcxAAy5X/lzdFo5N5je1gPiXXYz1OplKuHe5vuKel0Ojb/HKfZbIZer+fGN5fLubo5/izLAj9qNJMPnj59iv39fXS7XTfv/X4f3W7XtbNareLBgwf4tV/7NfzFX/wFHjx4gCiK8I//+I/4yU9+gsePH+PHP/4x9vf3cXR0hKdPn8b4UeksgU2ZAgDT6XQhebNer0e//du/zTbGwO8oipBOp1EoFFAoFFAul5FOp5HJZFCr1Zzc1LVFngDg1jnnR+WD3WOn0yn6/T6Gw6EDE5SX+K7KTV1rqVTKgX1UYlRusp3cEyww65Obyl/pdDrGN3Yvtu+pDLXyzQIxOi58XsfLyk1bB0n3Bi3HyjUlyn/dX/ms1huGoTvoIK9rX1Qm697JMY6iCNvb22i32+j1em5+R6OR400AqFQq2NjYwAcffIA/+ZM/webmJqbTKX784x/j888/d7zZbDbRbrfRbrfR7/cxnU5duywY6JOb7Pve3p5lj+uiS/NmEARXn+wzoYSumT755JM3DYC4NG9mMploeXn5UuDOZQ67LkOLMNbX1bdfdqLstPJs3md2HlT/WVtbw9/8zd/g93//952e7zuIt+Wrbst3Lis3rwsk+j6Ad4MgeAvHDPs/Afjv8x6mkkYlUQEDVbb4Q4WPyh+VOD3lojFjjTat03olsMx8Po/pdHpK2VNlkWVEURQznADEFEx9bjqdOgPXAjdU4BX8sQrobDZzv+1prRrkqgha7x9VfKkcqgFnwSyOqSqhAJyRoUahJQXHfEygz2mbSTRK2c55G6rW4QO/OB/T6RSj0cjNLeugscLPMpkMKpUKlpaWUC6XAQDj8RiDwcDNNceVnkW5XM7Ny3Q6BU8qLnoy/JroUnwJwK2tTCYTWy/kzfF4jEwmc8qjJJPJON7kuuR64HpTA5A/XJcK2ui8cN0p+AScjCWBSX6m8wvA7TOsh2uXBtV4PI55q/Ad3WdYtoIf7M9kMomNg+5RymPqxca/2Rb2jd/xRw1k5U39Ww1k5Xe+4+NVBX18RrQ1HFmf8oIKLBVI3Jd1jhRk4LgQrFVeJLCg6yObzaJcLmN1dRWVSsW1p9/vu4MFBQ7mCVJtjxXSKmhfI49emjcBOICR65xgoZUx5J1MJoNCoYB8Pu/WAueC408ZCyC23iz4yne5t5K3WYYCHLr2VB4qKEXeZt0kC04qiGPlppWL/J77gpWF+j7bp4cs3C/IR7oWJ5NJbH0p0MvvuP75bj6fPwUG6eGMAkTa3nnrUMtXMJfjxLII4Fp+IDio5ZMf9bCLc6X7JfcLvpPNZlGtVrG2tuZk52g0is0H5zmbzcZ0k8lk4sY7l8ud6q+Vo6/ZuHop3kwooYSunV5Kp+Xvi+4jiwDmXBf9Ivftpsjq0CrPLIYwbz1a72efF7QPYPIdsti/L0PXAhJFUTQNguB/BvB3ANIA/rcoiv7tjOcxmUycEkMFwj5D44RKMZVhKsJULlKplFN8dIIUjFCgADgxzFg/lUY9MQVOQAVVmNPpdMwItoqoLgIqYKqI8js1blgHFTGCQz4PGVWCdaysd4Qqj/ZZbYMFitQQJakBNu/0j+X7vIcsaRuo7FoQz/aN9ahRrf1TkIrK9Hg8dgaUei5Zj4tsNotCoeDWCo1VNbq1nuFwiHQ6jVarhUaj4ZRebbcdm9dNl+XLL96JGUO5XC7mQQAgZsyr10Aul3NKP3Cy/hTkZB26zvRZ1qvedOp9pOtXDU+7X2hZrJvAqraD+4ldz2pw0TNFQRTuRwq+cl+YB6Lq6bm+Y9tD4p5hx0wBXG27bbf21be/+kBN3UN1P1FgyraJY8jPfPuEAr8c09FoFOMbzqOCC9lsFsViEQxPJrCn46nggyUFJufRTZysvQ
xvAojJMfIFgFNrQ8dVveesdwnHzs4114+Cdl+0+xR4w7boGuM+orKYskvr0bWiMs0HYCoIpbJVw5jUc1Dr4HeFQiG
"text/plain": [
"<Figure size 1440x2160 with 25 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visualise one random test batch. Each row shows the input image followed by\n",
"# (ground truth, prediction) pairs for output channels 0 and 1.\n",
"PANELS_PER_ROW = 5\n",
"\n",
"dataiter = iter(testloader)\n",
"# Use builtin next(): newer PyTorch releases removed the iterator's .next() alias.\n",
"images, labels = next(dataiter)\n",
"\n",
"with torch.no_grad():\n",
"    output = model(images.to(device)).cpu()  # move to host once, not per panel\n",
"\n",
"n_rows = len(images)  # can be < TEST_STEP if the test set is smaller than one batch\n",
"fig, axes = plt.subplots(n_rows, PANELS_PER_ROW, figsize=(20, 30), squeeze=False)\n",
"for j in range(n_rows):\n",
"    panels = [\n",
"        images[j][0, :, :],\n",
"        labels[j][0, :, :], output[j][0, :, :].numpy(),\n",
"        labels[j][1, :, :], output[j][1, :, :].numpy(),\n",
"    ]\n",
"    for ax, panel in zip(axes[j], panels):\n",
"        ax.imshow(panel, cmap='gray')\n",
"\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}