Training the ML models
==================================
Before training the ML models, complete the following steps.
1) Set `data_dir`, `graph_dir`, `log_path`, and `ray_dir` in `paths.json` to convenient locations.
2) Run the following command to unpack the processed eICU data into mmap files, which are quick to load during training. The mmap files will be saved in `data_dir`.
```
python3 -m src.dataloader.convert
```
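To illustrate the two steps above, here is a minimal sketch using only the Python standard library. It is not the code in `src/dataloader/convert.py`. The helper names, the `demo.dat` file name, the float32 row layout, and the choice to create every configured location as a directory are all assumptions made for illustration. It writes a `paths.json` containing the four keys from step 1, then reads a single row from a binary file through a memory map, which is what makes mmap files cheap to load during training.

```python
import json
import mmap
import os
import struct
import tempfile

FLOAT_SIZE = struct.calcsize("<f")  # little-endian float32, 4 bytes


def write_paths_json(config_file, root):
    """Write a paths.json holding the four keys named in step 1 (hypothetical layout)."""
    keys = ("data_dir", "graph_dir", "log_path", "ray_dir")
    paths = {key: os.path.join(root, key) for key in keys}
    for location in paths.values():
        os.makedirs(location, exist_ok=True)  # assumption: every entry is a directory
    with open(config_file, "w") as f:
        json.dump(paths, f, indent=2)
    return paths


def write_rows(filename, rows):
    """Pack equal-length rows of floats into one flat binary file."""
    with open(filename, "wb") as f:
        for row in rows:
            f.write(struct.pack(f"<{len(row)}f", *row))


def read_row(filename, index, n_cols):
    """Read one row through a memory map without loading the whole file."""
    with open(filename, "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            offset = index * n_cols * FLOAT_SIZE
            return list(struct.unpack_from(f"<{n_cols}f", mm, offset))


def demo():
    """Create a throwaway config and data file, then fetch the second row."""
    with tempfile.TemporaryDirectory() as root:
        paths = write_paths_json(os.path.join(root, "paths.json"), root)
        filename = os.path.join(paths["data_dir"], "demo.dat")  # hypothetical file name
        write_rows(filename, [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        return read_row(filename, index=1, n_cols=2)
```

Because only the requested slice is read from disk, a data loader built this way can index into files that are larger than memory.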
The following commands train and evaluate the models introduced in our paper.
N.B.
- The models are structured using PyTorch Lightning. Graph neural networks and neighbourhood sampling are implemented using PyTorch Geometric.
- Our models assume a default graph built with k=3 under a k-closest scheme. To use another graph, see `read_graph_edge_list` in `src/dataloader/pyg_reader.py` and add a reference handle for your graph to `version2filename`.
- The default task is **In-Hospital Mortality Prediction (ihm)**. Add `--task los` to the command to run the **Length-of-Stay Prediction (los)** task instead.
- These commands use the best set of hyperparameters. To use other hyperparameters, remove `--read_best` from the command and refer to `src/args.py`.
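
As a rough illustration of the k-closest scheme mentioned in the notes above, the sketch below keeps, for each patient, the k most similar other patients and emits a directed edge list. It is not the repository's graph construction code. The function name `k_closest_edges`, the dense similarity matrix, and the tie-breaking rule are assumptions made for illustration.

```python
def k_closest_edges(similarity, k=3):
    """Return directed edges (i, j) where j is among the k nodes most similar to i.

    `similarity` is a square list of lists; larger values mean more similar.
    Self-loops are excluded and ties are broken by the smaller node index.
    """
    edges = []
    for i, row in enumerate(similarity):
        candidates = [j for j in range(len(row)) if j != i]
        candidates.sort(key=lambda j: (-row[j], j))  # most similar first
        edges.extend((i, j) for j in candidates[:k])
    return edges


# Toy similarity scores between four patients (made-up numbers).
toy_similarity = [
    [0, 5, 1, 3],
    [5, 0, 2, 4],
    [1, 2, 0, 6],
    [3, 4, 6, 0],
]
toy_edges = k_closest_edges(toy_similarity, k=2)
```

With k=2 each of the four toy patients gets exactly two outgoing edges. A real graph over many patients would need a sparse or batched similarity computation instead of a dense matrix.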
### a. LSTM-GNN
The following runs the training and evaluation for LSTM-GNN models. `--gnn_name` can be set as `gat`, `sage`, or `mpnn`. When `mpnn` is used, add `--ns_sizes 10` to the command.
```
python3 -m train_ns_lstmgnn --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best
```
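The flag rules described so far (the `--gnn_name` choices, `--ns_sizes 10` for `mpnn`, `--task los`, and `--read_best`) can be collected in a small helper. This is a hypothetical convenience wrapper, not part of the repository. It only assembles the command line shown above and does not run it.

```python
ALLOWED_GNNS = ("gat", "sage", "mpnn")
ALLOWED_TASKS = ("ihm", "los")


def lstmgnn_command(gnn_name="gat", task="ihm", read_best=True):
    """Build the LSTM-GNN training command as a list of arguments."""
    if gnn_name not in ALLOWED_GNNS:
        raise ValueError(f"gnn_name must be one of {ALLOWED_GNNS}")
    if task not in ALLOWED_TASKS:
        raise ValueError(f"task must be one of {ALLOWED_TASKS}")
    cmd = [
        "python3", "-m", "train_ns_lstmgnn",
        "--bilstm", "--ts_mask", "--add_flat", "--class_weights",
        "--gnn_name", gnn_name, "--add_diag",
    ]
    if gnn_name == "mpnn":
        cmd += ["--ns_sizes", "10"]  # required when mpnn is used
    if task == "los":
        cmd += ["--task", "los"]     # ihm is the default task
    if read_best:
        cmd.append("--read_best")    # omit to use your own hyperparameters
    return cmd


default_command = " ".join(lstmgnn_command())
```

The resulting list can be passed to `subprocess.run` from the repository root, or joined into a string for a job script.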
The following runs a hyperparameter search.
```
python3 -m src.hyperparameters.lstmgnn_search --bilstm --ts_mask --add_flat --class_weights --gnn_name gat --add_diag
```
### b. Dynamic LSTM-GNN
The following runs the training and evaluation for dynamic LSTM-GNN models. `--gnn_name` can be set as `gcn`, `gat`, or `mpnn`.
```
python3 -m train_dynamic --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn --read_best
```
The following runs a hyperparameter search.
```
python3 -m src.hyperparameters.dynamic_lstmgnn_search --bilstm --random_g --ts_mask --add_flat --class_weights --gnn_name mpnn
```
### c. GNN
The following runs the GNN models (with neighbourhood sampling). `--gnn_name` can be set as `gat`, `sage`, or `mpnn`. When `mpnn` is used, add `--ns_sizes 10` to the command.
```
python3 -m train_ns_gnn --ts_mask --add_flat --class_weights --gnn_name gat --add_diag --read_best
```
The following runs a hyperparameter search.
```
python3 -m src.hyperparameters.ns_gnn_search --ts_mask --add_flat --class_weights --gnn_name gat --add_diag
```
### d. LSTM (Baselines)
The following runs the baseline bi-LSTMs. To remove diagnoses from the input vector, remove `--add_diag` from the command.
```
python3 -m train_ns_lstm --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag --read_best
```
The following runs a hyperparameter search.
```
python3 -m src.hyperparameters.lstm_search --bilstm --ts_mask --add_flat --class_weights --num_workers 0 --add_diag
```
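Graph neural networks | LSTM-GNN time-series prediction in Python. LSTM-GNN for patient outcome prediction: a hybrid model that combines Long Short-Term Memory networks (LSTMs) for extracting temporal features with graph neural networks (GNNs) for extracting patient neighbourhood information. Work on predicting outcomes for intensive care unit (ICU) patients has focused mainly on physiological time-series data and has largely ignored sparse data such as diagnoses and medications. When such data are included, they are usually concatenated in the late stages of a model, which can make it hard to learn from rarer disease patterns. Here, diagnoses are instead used as relational information by connecting similar patients in a graph. LSTM-GNNs outperform the LSTM-only baseline on length-of-stay prediction tasks on the eICU database. Using graph neural networks to extract information from neighbouring patient cases is a promising research direction that yields tangible returns in supervised learning performance on electronic health records.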