"""
Train CLSL with AlexNet
This code is based on CMC (Contrastive Multiview Coding): https://github.com/HobbitLong/CMC/#contrastive-multiview-coding
Author: Shaochi Hu
"""
import os
import sys
import time
import torch
import torch.backends.cudnn as cudnn
import argparse
import socket
import numpy as np
from torchvision import transforms, datasets
import torchvision
import tensorboard_logger as tb_logger
from dataset import ImageFolderInstance, ImageFolderInstance_LoadAllImgToMemory
from models.alexnet import MyAlexNetCMC
from NCE.NCEAverage import NCEAverage, E2EAverage
from NCE.NCECriterion import NCECriterion
from NCE.NCECriterion import NCESoftmaxLoss
from util import adjust_learning_rate, AverageMeter,print_running_time, Logger, check_pytorch_idx_validation, get_anchor_pos_neg
from sampleIdx import SampleIndex, RandomBatchSamplerWithPosAndNeg, RandomBatchSamplerWithSupplementPosAndNeg
from calculateSampleDis import calSampleDisAndImgCaseStudy
from onlyCalculdateFeat import onlyCalFeat
def parse_option():
    """Parse command-line options for CLSL training and prepare the run's folders.

    In addition to parsing ``sys.argv``, this function:

    * converts ``--lr_decay_epochs`` (a comma-separated string) into a list of
      ints, ignoring empty pieces (e.g. a trailing comma);
    * derives ``opt.method`` ('softmax' or 'nce') from ``--softmax``;
    * builds a timestamped ``opt.model_name`` encoding the main hyper-parameters;
    * validates that every required path option was given and that
      ``--data_folder`` exists, *before* creating any output directory;
    * creates ``opt.model_folder`` and ``opt.tb_folder`` (``<path>/<model_name>``),
      ``opt.log_txt_path``, and the per-run result folder. ``opt.result_path`` is
      rebound to ``<result_path>/<model_name>``;
    * replaces ``sys.stdout`` with ``Logger(<log_txt_path>/log_<model_name>.txt)``
      so that printed output goes to the txt log file.

    Returns:
        argparse.Namespace: the parsed and augmented options.

    Raises:
        ValueError: if any of data_folder / model_path / tb_path / log_txt_path /
            result_path / test_data_folder is missing, or if ``--data_folder``
            is not an existing directory.
    """
    parser = argparse.ArgumentParser('argument for training')
    parser.add_argument('--print_freq', type=int, default=10, help='print every print_freq batchs')
    parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency')
    parser.add_argument('--save_freq', type=int, default=10, help='save model checkpoint every save_freq epoch')
    parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=18, help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
    # argparse `choices` guarantees contrastMethod is 'e2e' or 'membank'.
    parser.add_argument('--contrastMethod', type=str, default='e2e', choices=['e2e', 'membank'], help='method of contrast, e2e or membank')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.03, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='120,160,200', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam')
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # resume path
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    # model definition
    parser.add_argument('--model', type=str, default='alexnet', choices=['alexnet',
                                                                         'resnet50v1', 'resnet101v1', 'resnet18v1',
                                                                         'resnet50v2', 'resnet101v2', 'resnet18v2',
                                                                         'resnet50v3', 'resnet101v3', 'resnet18v3'])
    parser.add_argument('--softmax', action='store_true', help='using softmax contrastive loss rather than NCE')
    parser.add_argument('--nce_k', type=int, default=4096)  # negative sample number
    parser.add_argument('--nce_t', type=float, default=0.07)  # temperature parameter
    parser.add_argument('--nce_m', type=float, default=0.5)  # memory update rate
    parser.add_argument('--feat_dim', type=int, default=128, help='dim of feat for inner product')  # dimension of network's output

    # specify folder
    # Training data folder, i.e. the anchor / positive / negative sample folder.
    parser.add_argument('--data_folder', type=str, default=None, help='path to training data')
    # Test data folder, i.e. the folder containing all video frames.
    parser.add_argument('--test_data_folder', type=str, default=None, help='path to testing data')
    parser.add_argument('--model_path', type=str, default=None, help='path to save model')
    parser.add_argument('--tb_path', type=str, default=None, help='path to tensorboard')
    parser.add_argument('--log_txt_path', type=str, default=None, help='path to log file')
    # After training, the case study of inter-image distances is saved under this path.
    parser.add_argument('--result_path', type=str, default=None, help='path to sample dis and img case study')

    # data crop threshold
    parser.add_argument('--crop_low', type=float, default=0.8, help='low area in crop')
    parser.add_argument('--comment_info', type=str, default='', help='Comment message, donot influence program')
    parser.add_argument('--supplement_pos_neg_txt_path', type=str, default='')
    parser.add_argument('--training_data_cache_method', type=str, default='default', choices=['default', 'memory', 'GPU'], help='where to save training data. \'memory\' or \'GPU\' will load all training data into memory or GPU at begining to speed up data reading at training stage.')

    opt = parser.parse_args()

    # "120,160,200" -> [120, 160, 200]; empty pieces (e.g. a trailing comma or an
    # empty string) are skipped instead of crashing on int('').
    opt.lr_decay_epochs = [int(piece) for piece in opt.lr_decay_epochs.split(',') if piece.strip()]

    opt.method = 'softmax' if opt.softmax else 'nce'

    curTime = time.strftime("%Y%m%d_%H_%M_%S", time.localtime())
    opt.model_name = '{}_lossMethod_{}_NegNum_{}_Model_{}_lr_{}_decay_{}_bsz_{}_featDim_{}_contrasMethod_{}_{}'.format(curTime, opt.method, opt.nce_k, opt.model, opt.learning_rate,
                                                                                                                  opt.weight_decay, opt.batch_size, opt.feat_dim, opt.contrastMethod, opt.comment_info)

    # Validate all inputs before touching the filesystem, so a bad invocation
    # does not leave empty per-run output directories behind.
    required = ('data_folder', 'model_path', 'tb_path', 'log_txt_path', 'result_path', 'test_data_folder')
    if any(getattr(opt, name) is None for name in required):
        raise ValueError('one or more of the folders is None: data_folder | model_path | tb_path | log_txt_path | result_path | test_data_folder')
    if not os.path.isdir(opt.data_folder):
        raise ValueError('data path not exist: {}'.format(opt.data_folder))

    # Per-run output folders are keyed by the timestamped model name.
    opt.model_folder = os.path.join(opt.model_path, opt.model_name)
    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    opt.result_path = os.path.join(opt.result_path, opt.model_name)
    for folder in (opt.model_folder, opt.tb_folder, opt.log_txt_path, opt.result_path):
        os.makedirs(folder, exist_ok=True)

    # Write printed output to the txt log file.
    log_file_name = os.path.join(opt.log_txt_path, 'log_' + opt.model_name + '.txt')
    sys.stdout = Logger(log_file_name)

    if opt.comment_info != '':
        print('comment message : ', opt.comment_info)
    print('start program at ' + time.strftime("%Y_%m_%d %H:%M:%S", time.localtime()))
    print('Dataset :', opt.data_folder)

    return opt
def get_train_loader(args):
data_folder = os.path.join(args.data_folder, 'train')
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
train_transform_withRandom = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
transforms.RandomGrayscale(p=0.5),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
transforms.GaussianBlur(9, (0.1,3)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
transforms.RandomPerspective(distortion_scale=0.5, p=0.5)
])
train_transform_withoutRandom = transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
# 选择训练数据读取方式
# default:传统方式,即每次仅读取一个batch的数据�