Alpha冲刺2/3 进展及体会

一、项目进展

  • 本周的进展主要是前端界面进行了美化,并且重新设计了一下功能

  • 这我主要负责了把训练好的模型应用到后端,输入图片的路径就可以输出检测的结果

    1. 首先,研究了SAFNet的模型训练代码

      1
      2
      3
      4
      5
      data_path = './data/SAFNet'
      data_traingt = sio.loadmat(os.path.join(data_path, 'mask_train.mat'))['mask_train']
      data_testgt = sio.loadmat(os.path.join(data_path, 'mask_test.mat'))['mask_test']
      im1 = sio.loadmat(os.path.join(data_path, 'data_1.mat'))['data']
      im2 = sio.loadmat(os.path.join(data_path, 'data_2.mat'))['data']

      这里需要注意,读入数据时的通道数是1维,所以当在网页的后端读取bmp图片的格式,要注意把3通道的数据读入到一个通道里面

      1
      2
      im1 = io.imread(os.path.join(origin, path1))[:, :, 0].astype(np.float32)
      im2 = io.imread(os.path.join(origin, path2))[:, :, 0].astype(np.float32)

      然后就是利用输入的图像来创建Patch,并利用创建的Patch来对模型进行训练,这里创建的Patch数量等于gt中元素大于1的个数

      1
      2
      3
      # 创建Patch:数量等于gt中元素大于1的个数
      train_1, labels ,_ = createImgCube(im1, data_traingt, createPosWithoutZero(im1, data_traingt), windowSize=windowSize)
      train_2, _ ,_ = createImgCube(im2, data_traingt, createPosWithoutZero(im2, data_traingt), windowSize=windowSize)

      再之后做的就是数据增强,并且按照一定的比例划分训练集和测试集,最后在生成检测图片的时候,同样也是按照Patch进行生成的,先对小的区域进行检测,最后再把小的区域总共合起来

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      for i in range(height):
      for j in range(width):
      # if preclassify_lab[i, j]!= 1.5:
      # outputs[i, j] = preclassify_lab[i, j]

      # else:
      patch1 = im1[i:i+windowSize, j:j+windowSize, :]
      patch1 = patch1.reshape(1, patch1.shape[0], patch1.shape[1], patch1.shape[2])
      X_test_image = torch.FloatTensor(patch1.transpose(0, 3, 1, 2)).to(device)

      patch2 = im2[i:i+windowSize, j:j+windowSize, :]
      patch2 = patch2.reshape(1, patch2.shape[0], patch2.shape[1], patch2.shape[2])
      X_test_image1 = torch.FloatTensor(patch2.transpose(0, 3, 1, 2)).to(device)

      _, _, prediction = model(X_test_image, X_test_image1)
      prediction = np.argmax(prediction.detach().cpu().numpy(), axis=1)
      outputs[i][j] = prediction
      if i % 20 == 0:
      print('... ... row ', i, ' handling ... ...')
    2. 其次,利用保存的模型参数在本地编写了检测SAR图片的脚本,之后后端只需要调用该函数便可以得到检测出的结果

    1. 图片的检测结果如下所示

二、心得体会

这周仔细研究了论文和代码之后,收获还是很大的,理解了大部分的代码,并且自己也利用了保存的模型实现了检测的效果。总而言之,下周继续加油吧。