# Ultralytics YOLO 🚀, AGPL-3.0 license
import sys, os, torch, math, time, warnings
import torch_pruning as tp
import matplotlib
matplotlib.use('AGG')
import matplotlib.pylab as plt
import torch.nn as nn
from torch import optim
from thop import clever_format
from functools import partial
from torch import distributed as dist
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from datetime import datetime
from copy import copy, deepcopy
from pathlib import Path
import numpy as np
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, TQDM, clean_url, colorstr, emojis, yaml_save, callbacks, __version__
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
from ultralytics.utils.checks import check_imgsz, print_args, check_amp
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.torch_utils import ModelEMA, EarlyStopping, one_cycle, init_seeds, select_device
from ultralytics.utils.distill_loss import LogicalLoss, FeatureLoss
from ultralytics.nn.extra_modules.kernel_warehouse import get_temperature
def get_activation(feat, backbone_idx=-1):
    """
    Build a forward hook that records a module's output into ``feat``.

    Args:
        feat (list): List the hook appends to on every call.
        backbone_idx (int): When -1 (default) the whole output is recorded. Otherwise the
            output is treated as a sequence, left-padded with ``None`` up to 5 entries,
            and the entry at this index of the padded sequence is recorded.

    Returns:
        (callable): A ``hook(model, inputs, outputs)`` function that returns None.

    Raises:
        IndexError: Raised by the hook when ``backbone_idx`` is outside the padded sequence.
    """
    padded_len = 5  # sequence length the index refers to

    def hook(model, inputs, outputs):
        """Append the selected output to ``feat`` without modifying ``outputs``."""
        if backbone_idx == -1:
            feat.append(outputs)
            return
        # Pad a copy rather than inserting into ``outputs``: the hook receives the very
        # object the module returned, so in-place edits would leak into the forward pass
        # (and ``insert`` is unavailable on tuple outputs).
        missing = max(0, padded_len - len(outputs))
        padded = [None] * missing + list(outputs)
        feat.append(padded[backbone_idx])

    return hook
class DetectionDistiller(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the distiller: configuration, device, output directories, dataset and bookkeeping.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks used instead of ``callbacks.get_default_callbacks()``.
            Defaults to None.

    Raises:
        RuntimeError: If checking the dataset fails; the original exception is chained.
    """
    # NOTE(review): BaseTrainer.__init__ is not called here; all state is set up below.
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.model = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / 'weights'  # weights dir
    if RANK in (-1, 0):
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args
    self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in ('cpu', 'mps'):
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = self.args.model
    try:
        if self.args.task == 'classify':
            self.data = check_cls_dataset(self.args.data)
        elif self.args.data.split('.')[-1] in ('yaml', 'yml') or self.args.task in ('detect', 'segment', 'pose'):
            self.data = check_det_dataset(self.args.data)
            if 'yaml_file' in self.data:
                self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

    self.trainset, self.testset = self.get_dataset(self.data)
    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    # Distillation loss slots, unset until assigned elsewhere
    self.logical_disloss = None
    self.feature_disloss = None
    self.loss_names = ['Loss']
    self.csv = self.save_dir / 'results.csv'
    self.plot_idx = [0, 1, 2]

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in (-1, 0):
        callbacks.add_integration_callbacks(self)
def build_dataset(self, img_path, mode='train', batch=None):
    """
    Create the YOLO dataset for one split.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
    """
    # Use the model's largest stride when a model is set, but never go below 32.
    if self.model:
        model_stride = int(de_parallel(self.model).stride.max())
    else:
        model_stride = int(0)
    grid_size = max(model_stride, 32)
    use_rect = mode == 'val'  # rect is enabled for the validation split only
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=use_rect, stride=grid_size)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
    """
    Construct and return the dataloader for one split.

    Args:
        dataset_path (str): Path handed to ``build_dataset``.
        batch_size (int): Batch size for the dataset and the loader. Defaults to 16.
        rank (int): Rank passed to ``torch_distributed_zero_first`` and ``build_dataloader``.
        mode (str): Either 'train' or 'val'.

    Returns:
        The object returned by ``build_dataloader``.
    """
    assert mode in ['train', 'val']
    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
        dataset = self.build_dataset(dataset_path, mode, batch_size)
    shuffle = mode == 'train'  # only the training split is shuffled
    if getattr(dataset, 'rect', False) and shuffle:
        LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
        shuffle = False
    # Any mode other than 'train' gets twice the configured workers.
    workers = self.args.workers if mode == 'train' else self.args.workers * 2
    return build_dataloader(dataset, batch_size, workers, shuffle, rank)  # return dataloader
def preprocess_batch(self, batch):
    """Move ``batch['img']`` to the trainer device, cast it to float and divide by 255; return the same dict."""
    images = batch['img']
    images = images.to(self.device, non_blocking=True)
    batch['img'] = images.float() / 255
    return batch
def set_model_attributes(self):
    """Attach the number of classes, the class names and the hyperparameters to the model."""
    # The box/cls gain scaling by detection layers, class count and image size is not
    # applied here; only the three attributes below are set.
    net = self.model
    dataset_info = self.data
    net.nc = dataset_info['nc']  # number of classes
    net.names = dataset_info['names']  # class names
    net.args = self.args  # hyperparameters
    # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build a ``DetectionModel`` for this dataset and optionally load weights into it.

    Args:
        cfg: Model configuration forwarded to ``DetectionModel``. Defaults to None.
        weights: If truthy, passed to ``model.load``. Defaults to None.
        verbose (bool): Verbose construction is requested only when this is truthy and RANK is -1.

    Returns:
        The constructed ``DetectionModel``.
    """
    num_classes = self.data['nc']
    be_verbose = verbose and RANK == -1
    detector = DetectionModel(cfg, nc=num_classes, verbose=be_verbose)
    if not weights:
        return detector
    detector.load(weights)
    return detector
def get_validator(self):
"""Returns a Detec
没有合适的资源?快使用搜索试试~ 我知道了~
YOLOv8知识蒸馏源码
共16个文件
pyc:9个
py:4个
md:1个
需积分: 0 8 下载量 75 浏览量
2024-03-27
12:41:09
上传
评论 1
收藏 1.34MB ZIP 举报
温馨提示
YOLOv8知识蒸馏源码
资源推荐
资源详情
资源评论
收起资源包目录
YOLOv8知识蒸馏源码.zip (16个子文件)
ultralytics
distill.md 33KB
get_FPS.py 3KB
ultralytics
utils
distill_loss.py 22KB
callbacks
__pycache__
__pycache__
models
__pycache__
__init__.cpython-38.pyc 296B
yolo
detect
__pycache__
train.cpython-38.pyc 7KB
predict.cpython-38.pyc 2KB
val.cpython-38.pyc 12KB
__init__.cpython-38.pyc 349B
distill.py 33KB
__pycache__
model.cpython-38.pyc 1KB
__init__.cpython-38.pyc 374B
cfg
__pycache__
__init__.cpython-38.pyc 17KB
default.yaml 8KB
__pycache__
__init__.cpython-38.pyc 711B
ultralytics-v8.1.9.zip 1.39MB
distill.py 1KB
共 16 条
- 1
资源评论
m0_51579041
- 粉丝: 7295
- 资源: 16
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功