# Visualizing Convolutional Networks for MRI-based Diagnosis of Alzheimer’s Disease
**Johannes Rieke, Fabian Eitel, Martin Weygandt, John-Dylan Haynes and Kerstin Ritter**
Our paper was presented at the [MLCN workshop](https://mlcn2018.com/) at MICCAI 2018 in Granada ([Slides](https://drive.google.com/open?id=1EKHvlWq4_-NC7HQPAbZc_ZaeNZMTQwgh)).
**Preprint:** http://arxiv.org/abs/1808.02874
**Abstract:** Visualizing and interpreting convolutional neural networks (CNNs) is an important task to increase trust in automatic medical decision making systems. In this study, we train a 3D CNN to detect Alzheimer’s disease based on structural MRI scans of the brain. Then, we apply four different gradient-based and occlusion-based visualization methods that explain the network’s classification decisions by highlighting relevant areas in the input image. We compare the methods qualitatively and quantitatively. We find that all four methods focus on brain regions known to be involved in Alzheimer’s disease, such as inferior and middle temporal gyrus. While the occlusion-based methods focus more on specific regions, the gradient-based methods pick up distributed relevance patterns. Additionally, we find that the distribution of relevance varies across patients, with some having a stronger focus on the temporal lobe, whereas for others more cortical areas are relevant. In summary, we show that applying different visualization methods is important to understand the decisions of a CNN, a step that is crucial to increase clinical impact and trust in computer-based decision support systems.
![Heatmaps](figures/heatmaps-ad.png)
## Quickstart
You can use the visualization methods in this repo on your own model (PyTorch; for other frameworks see below) like this:

    from interpretation import sensitivity_analysis
    from utils import plot_slices

    cnn = load_model()      # placeholder: load your trained PyTorch model here
    mri_scan = load_scan()  # placeholder: load your image here
    heatmap = sensitivity_analysis(cnn, mri_scan, cuda=True)
    plot_slices(mri_scan, overlay=heatmap)

`heatmap` is a numpy array containing the relevance heatmap. The methods should work for 2D and 3D images alike. Currently, four methods are implemented and tested: `sensitivity_analysis`, `guided_backprop`, `occlusion`, `area_occlusion`. There is also a rough implementation of `grad_cam`, which seems to work on 2D photos, but not on brain scans. Please look at `interpretation.py` for further documentation.
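To illustrate what an occlusion-based method computes, here is a minimal, framework-free sketch. It is not the implementation in `interpretation.py`: the function name `occlusion_heatmap`, its parameters, and the toy `toy_predict` function are assumptions made for illustration only. The idea is to slide a patch over the image, overwrite the covered pixels or voxels with a constant, and record how much the model's output drops.

```python
import itertools

import numpy as np


def occlusion_heatmap(predict, image, patch_size=2, fill_value=0.0):
    """Relevance of a patch = drop in predict(image) when the patch is occluded.

    `predict` is any callable that maps an array to a scalar score
    (a hypothetical stand-in for a model's class probability).
    Works for arrays of any dimensionality, e.g. 2D slices or 3D volumes.
    """
    image = np.asarray(image, dtype=float)
    baseline = predict(image)
    heatmap = np.zeros_like(image)

    # Start coordinates of non-overlapping patches along each axis.
    starts = [range(0, size, patch_size) for size in image.shape]
    for corner in itertools.product(*starts):
        region = tuple(slice(c, c + patch_size) for c in corner)
        occluded = image.copy()
        occluded[region] = fill_value
        # A large drop means the score depended on this region.
        heatmap[region] = baseline - predict(occluded)
    return heatmap


# Toy "model": only the top-left 2x2 corner influences the score.
def toy_predict(x):
    return float(x[:2, :2].sum())


toy_image = np.ones((4, 4))
toy_heatmap = occlusion_heatmap(toy_predict, toy_image, patch_size=2)
```

In this toy case, only the top-left patch receives non-zero relevance, because it is the only region the toy score depends on. With a real CNN, `predict` would wrap a forward pass and return the probability of the class of interest.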
## Code Structure
The codebase uses PyTorch and Jupyter notebooks. The main files for the paper are:
- `training.ipynb` is the notebook to train the model and perform cross validation.
- `interpretation-mri.ipynb` contains the code to create relevance heatmaps with different visualization methods. It also includes the code to reproduce all figures and tables from the paper.
- All `*.py` files contain methods that are imported in the notebooks above.
Additionally, there are two other notebooks:
- `interpretation-photos.ipynb` uses the same visualization methods as in the paper but applies them to 2D photos. This might be an easier introduction to the topic.
- `small-dataset.ipynb` contains some old code to run a similar experiment on a smaller dataset.
## Trained Model and Heatmaps
If you don't want to train the model and/or compute the heatmaps yourself, you can download my results: [Here](https://drive.google.com/file/d/14m6v9DOubxrid20BbVyTgOOVF-K7xwV-/view?usp=sharing) is the final model that I used to produce all heatmaps in the paper (as a PyTorch state dict; see the paper or code for details on how the model was trained). And [here](https://drive.google.com/open?id=1feEpR-GhKUe_YTkKu9dlnYIKsyF6fyei) are the average relevance heatmaps (as a compressed NumPy `.npz` file). Please have a look at `interpretation-mri.ipynb` for instructions on how to load and use these files.
## Data
The MRI scans used for training are from the [Alzheimer's Disease Neuroimaging Initiative (ADNI)](http://adni.loni.usc.edu/). The data is free, but you need to apply for access at http://adni.loni.usc.edu/. Once you have an account, go [here](http://adni.loni.usc.edu/data-samples/access-data/) and log in.
### Tables
This repo includes CSV tables with metadata for all images we used (`data/ADNI/ADNI_tables`). These tables were made by combining several data tables from ADNI. There is one table for 1.5 Tesla scans and one for 3 Tesla scans. In the paper, we trained only on the 1.5 Tesla images.
### Images
To download the corresponding images, log in on the ADNI page, go to "Download" -> "Image Collections" -> "Data Collections". In the box on the left, select "Other shared collections" -> "ADNI" -> "ADNI1:Annual 2 Yr 1.5T" (or the corresponding collection for 3T) and download all images. We preprocessed all images by non-linear registration to a 1 mm isotropic ICBM template via [ANTs](http://stnava.github.io/ANTs/) with default parameters, using the quick registration script from [here](https://github.com/ANTsX/ANTs/blob/master/Scripts/antsRegistrationSyNQuick.sh).
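As a rough sketch of the registration step, the helper below assembles a call to `antsRegistrationSyNQuick.sh` with its standard `-d` (dimension), `-f` (fixed image), `-m` (moving image) and `-o` (output prefix) options and no further flags, i.e. default parameters. The function name `build_registration_command` and all file names are hypothetical; in particular, the template file name is an assumption and not taken from this repo. The script writes the registered image to `<output prefix>Warped.nii.gz`.

```python
def build_registration_command(moving_image, template, output_prefix):
    """Return the argument list for a default-parameter quick registration."""
    return [
        "antsRegistrationSyNQuick.sh",
        "-d", "3",             # 3D images
        "-f", template,        # fixed image: the ICBM template
        "-m", moving_image,    # moving image: the raw ADNI scan
        "-o", output_prefix,   # result is written to <prefix>Warped.nii.gz
    ]


# Hypothetical file names, for illustration only.
cmd = build_registration_command(
    moving_image="raw_scan.nii.gz",
    template="icbm_template_1mm.nii.gz",
    output_prefix="scan_",
)

# To actually run it (requires ANTs to be installed and on your PATH):
# import subprocess
# subprocess.check_call(cmd)
```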
To be consistent with the codebase, put the images into the folders `data/ADNI/ADNI_2Yr_15T_quick_preprocessed` (for the 1.5 Tesla images) or `data/ADNI/ADNI_2Yr_3T_preprocessed` (for the 3 Tesla images). Within these folders, each image should have the following path: `<PTID>/<Visit (spaces removed)>/<PTID>_<Scan.Date (/ replaced by -)>_<Visit (spaces removed)>_<Image.ID>_<DX>_Warped.nii.gz`. If you want to use a different directory structure, you need to change the method `get_image_filepath` and/or the filenames in `datasets.py`.
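The path template above can be spelled out as a small helper. This is only a sketch of the naming scheme: `build_image_path` is a hypothetical function (the repo's own logic lives in `get_image_filepath` in `datasets.py` and may differ in detail), and the example values are made-up placeholders, not real ADNI identifiers.

```python
import posixpath


def build_image_path(ptid, visit, scan_date, image_id, dx):
    """Build the relative path of a preprocessed image from its table entries."""
    visit_clean = visit.replace(" ", "")      # spaces removed
    date_clean = scan_date.replace("/", "-")  # "/" replaced by "-"
    filename = "{}_{}_{}_{}_{}_Warped.nii.gz".format(
        ptid, date_clean, visit_clean, image_id, dx
    )
    # posixpath keeps the output identical on all platforms;
    # os.path.join would be the usual choice for local file access.
    return posixpath.join(ptid, visit_clean, filename)


# Made-up placeholder values, for illustration only.
example_path = build_image_path(
    ptid="000_S_0000",
    visit="Month 12",
    scan_date="1/2/2007",
    image_id=12345,
    dx="AD",
)
# example_path == "000_S_0000/Month12/000_S_0000_1-2-2007_Month12_12345_AD_Warped.nii.gz"
```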
### Users from Ritter/Haynes lab
If you're working in the Ritter/Haynes lab at Charité Berlin, you don't need to download any data, but simply uncomment the correct `ADNI_DIR` variable in `datasets.py`.
## Requirements
- Python 2 (mostly compatible with Python 3 syntax, but not tested)
- Scientific packages (included with anaconda): numpy, scipy, matplotlib, pandas, jupyter, scikit-learn
- Other packages: tqdm, tabulate
- PyTorch: torch, torchvision (tested with 0.3.1, but mostly compatible with 0.4)
- torchsample: I made a custom fork of torchsample which fixes some bugs. You can download it from https://github.com/jrieke/torchsample or install directly via `pip install git+https://github.com/jrieke/torchsample`. Please use this fork instead of the original package, otherwise the code will break.
## Non-PyTorch Models
If your model is not in PyTorch, but you still want to use the visualization methods, you can try to convert the model to PyTorch ([overview of conversion tools](https://github.com/ysh329/deep-learning-model-convertor)).
For Keras to PyTorch, I can recommend [nn-transfer](https://github.com/gzuidhof/nn-transfer). If you use it, keep in mind that by default, PyTorch uses the channels-first format and Keras the channels-last format for images. Even though nn-transfer takes care of this difference for the orientation of the convolution kernels, you may still need to permute your dimensions in the PyTorch model between the convolutional and fully-connected stage (for 3D images, I did `x = x.permute(0, 2, 3, 4, 1).contiguous()`). The safest bet is to switch Keras to channels-first as well; then nn-transfer should handle everything by itself.
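To see why the permutation matters, consider what happens when the feature map is flattened before the first fully-connected layer. The sketch below uses NumPy instead of PyTorch to stay framework-free (`np.transpose` plays the role of `permute` here); the tiny array shape is an arbitrary assumption for illustration.

```python
import numpy as np

# Hypothetical feature map in PyTorch layout:
# (batch, channels, depth, height, width).
n, c, d, h, w = 1, 2, 2, 2, 2
channels_first = np.arange(n * c * d * h * w).reshape(n, c, d, h, w)

# The same data as Keras holds it by default:
# (batch, depth, height, width, channels).
channels_last = np.moveaxis(channels_first, 1, -1)

# Flattening the channels-first array directly gives one feature order ...
flat_naive = channels_first.reshape(n, -1)

# ... while permuting to (0, 2, 3, 4, 1) first reproduces the order
# that a flatten on the channels-last (Keras) layout would produce.
flat_permuted = np.transpose(channels_first, (0, 2, 3, 4, 1)).reshape(n, -1)
flat_keras = channels_last.reshape(n, -1)
```

Here `flat_permuted` matches `flat_keras`, whereas `flat_naive` does not. Fully-connected weights transferred from a channels-last Keras model expect the `flat_keras` order, which is why the permutation is needed before flattening.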
## Citation
If you use our code, please cite our [paper](http://arxiv.org/abs/1808.02874):

    @inproceedings{rieke2018,
      title={Visualizing Convolutional Networks for MRI-based Diagnosis of Alzheimer's Disease},
      author={Rieke, Johannes and Eitel, Fabian and Weygandt, Martin and Haynes, John-Dylan and Ritter, Kerstin},
      booktitle={Machine Learning in Clinical Neuroimaging (MLCN)},
      year={2018}
    }
