# pytorch实用技巧
[**1. 得到模型参数数量**](#得到模型参数数量)
[**2. 特定网络结构参数分布初始化**](#特定网络结构参数分布初始化)
[**3. view函数**](#view函数)
[**4. unsqueeze函数**](#unsqueeze函数)
[**5. squeeze函数**](#squeeze函数)
[**6. pytorch自定义损失函数**](#pytorch自定义损失函数)
[**7. pytorch自定义矩阵W**](#pytorch自定义矩阵w)
[**8. 自定义操作torch.autograd.Function**](#autograd)
[**9. pytorch embedding设置不可导**](#pytorch_embedding设置不可导)
[**10. 中文tokenizer**](#中文tokenizer)
[**11. Accelerate: 适用于多GPU、TPU、混合精度训练**](#accelerate)
[**12. pytorch删除一层网络**](#pytorch删除一层网络)
---
## 得到模型参数数量
```python
def get_parameter_number(model):
    """Count the parameters of a model.

    Args:
        model: an object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        dict with two integer entries:
        ``'Total'`` - number of elements over all parameters,
        ``'Trainable'`` - number of elements over the parameters whose
        ``requires_grad`` flag is set.
    """
    total_num = 0
    trainable_num = 0
    # Single pass: each parameter contributes to the total, and to the
    # trainable count only when gradients are enabled for it.
    for p in model.parameters():
        count = p.numel()
        total_num += count
        if p.requires_grad:
            trainable_num += count
    return {'Total': total_num, 'Trainable': trainable_num}
# Example call; `model` is assumed to be an already-constructed nn.Module - TODO confirm.
get_parameter_number(model)
```
## 特定网络结构参数分布初始化
```python
class AutoEncoder(nn.Module):
    """Encoder/decoder pair whose layers get a custom initial distribution.

    After the two sub-networks are built, every sub-module is visited and
    re-initialised by type:

    * ``nn.Conv2d`` / ``nn.Linear`` weights: Xavier-uniform.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0.

    Args:
        feedback_bits: forwarded unchanged to ``Encoder`` and ``Decoder``.
    """

    def __init__(self, feedback_bits):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(feedback_bits)
        self.decoder = Decoder(feedback_bits)
        ###------- initialise parameter distributions ------###
        # self.modules() walks this module and all nested sub-modules, so the
        # layers inside encoder and decoder are covered as well.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None;
                # skip those instead of crashing inside nn.init.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        ###--------------------------------------------------###

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        feature = self.encoder(x)
        out = self.decoder(feature)
        return out
```
## view函数
```python
# Demonstrates Tensor.view: the same 6 elements are presented with a new shape.
import torch as t
a = t.arange(0,6)
print(a)
# tensor([0, 1, 2, 3, 4, 5])
# (on current PyTorch releases integer arguments to arange give an int64
# tensor; the float output "0., 1., ..." only appeared on old releases)
a.view(2,-1) # 2 rows; -1 tells PyTorch to infer the column count (6 / 2 = 3)
# tensor([[0, 1, 2],
#         [3, 4, 5]])
```
## unsqueeze函数
```python
# Build the 2x3 tensor used by the unsqueeze calls that follow.
import torch as t
a = t.arange(0,6).view(2,3)
print(a)
# tensor([[0, 1, 2],
#         [3, 4, 5]])
# (int64 on current PyTorch releases; old releases printed floats here)
print(a.size())
# torch.Size([2, 3])
```
```python
# unsqueeze(dim) inserts a new axis of length 1 at position `dim`; the effect
# is similar to a reshape. `a` is the 2x3 tensor, so the sizes become:
a.unsqueeze(0).size()
# torch.Size([1, 2, 3])
a.unsqueeze(1).size()
# torch.Size([2, 1, 3])
a.unsqueeze(2).size()
# torch.Size([2, 3, 1])
```
## squeeze函数
```python
# Demonstrates Tensor.squeeze: it drops every axis whose length is 1.
import torch
a = torch.Tensor([[1,2,3]])
# tensor([[1., 2., 3.]])   -> shape (1, 3)
a.squeeze()
# tensor([1., 2., 3.])     -> shape (3,), the leading length-1 axis is gone
a = torch.Tensor([1,2,3,4,5,6])
a.view(2,3)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])
# NOTE: the result of view() is not assigned, so `a` itself is still the
# 1-D tensor with 6 elements.
a.squeeze()
# tensor([1., 2., 3., 4., 5., 6.])   -> no length-1 axis, so nothing changes
```
## pytorch自定义损失函数
![nwrmsle.png](pic/nwrmsle.png)
```python
# pytorch自定义损失函数 Normalized Weighted Root Mean Squared Logarithmic Error(NWRMSLE)
# 这里y真实值需要提前进行log1p的操作
# 加入了sample_weights,和keras里model.fit(x,sample_weights)一样
from torch.functional import F
class my_rmseloss(nn.Module):
    """Normalized Weighted Root Mean Squared (Logarithmic) Error.

    Computes ``sqrt(sum(w * (input - target) ** 2) / sum(w))``.

    The module does not take logarithms itself: the true values are expected
    to have had ``log1p`` applied beforehand. ``sample_weights`` plays the
    same role as the per-sample weights of Keras' ``model.fit``.

    The loss is built from public tensor operations only. Routing it through
    the private ``torch._C._nn.mse_loss`` (which mean-reduces by default on
    current PyTorch) makes the weights cancel out and leaves the unweighted
    result without its square root.
    """

    def __init__(self):
        super(my_rmseloss, self).__init__()

    def forward(self, input, target, sample_weights=None):
        """Return the (weighted) root mean squared error as a 0-dim tensor.

        Args:
            input: predictions.
            target: ground truth; must not require gradients.
            sample_weights: optional weights multiplied element-wise with the
                squared error; ``None`` means every element has weight 1.

        Raises:
            AssertionError: if ``target.requires_grad`` is True.
        """
        self._assert_no_grad(target)
        f_revis = lambda a, b, w: ((a - b) ** 2) * w  # weighted squared error
        return self._pointwise_loss(f_revis, None, input, target, sample_weights)

    def _pointwise_loss(self, lambd, lambd_optimized, input, target, sample_weights):
        """Reduce the element-wise term produced by ``lambd`` to the final loss.

        ``lambd_optimized`` is kept only so the signature stays compatible;
        it is not called.
        """
        if sample_weights is None:
            # Unit weights: the normaliser is simply the element count.
            d = lambd(input, target, 1)
            return torch.sqrt(torch.mean(d))
        d = lambd(input, target, sample_weights)
        return torch.sqrt(torch.div(torch.sum(d), torch.sum(sample_weights)))

    def _assert_no_grad(self, tensor):
        """Reject targets that would need a gradient."""
        assert not tensor.requires_grad, \
            "nn criterions don't compute the gradient w.r.t. targets - please " \
            "mark these tensors as not requiring gradients"
```
## pytorch自定义矩阵w
比如我在DigitCaps中定义了一个W的矩阵,想要这个矩阵可导,则用nn.Parameter包一下
```python
class DigitCaps(nn.Module):
    """Capsule layer showing how to register a hand-made weight tensor.

    Wrapping the tensor in ``nn.Parameter`` registers it with the module, so
    it is returned by ``parameters()`` and receives gradients.

    Args:
        num_capsules: number of output capsules.
        num_routes: number of input routes.
        in_channels: length of each input vector.
        out_channels: length of each projected vector.
    """

    def __init__(self, num_capsules=10, num_routes=32 * 40, in_channels=10, out_channels=16):
        super(DigitCaps, self).__init__()
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        # Leading axis of length 1 so the weights broadcast over the batch.
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels), requires_grad=True)  # trainable

    def forward(self, x):
        """Project every route vector with the per-(route, capsule) matrix.

        Args:
            x: tensor of shape ``(batch, num_routes, in_channels)``.

        Returns:
            ``u_hat`` of shape
            ``(batch, num_routes, num_capsules, out_channels, 1)``.
        """
        # (batch, routes, in) -> (batch, routes, 1, in, 1). torch.matmul
        # broadcasts the length-1 batch axis of W and the length-1 capsule
        # axis of x, so no explicit copies of W or x are needed.
        x = x.unsqueeze(2).unsqueeze(4)
        u_hat = torch.matmul(self.W, x)
        # NOTE(review): the snippet ends after computing u_hat; it is returned
        # so the layer yields a usable output. Any routing step that may
        # follow in the full model is not shown here.
        return u_hat
```
```python
# Hand the module's parameters to the optimizer; W is among them because it
# is an nn.Parameter. If DigitCaps is used inside another module, pass the
# parameters of the outermost (main) module to Adam instead.
# NOTE(review): assumes `from torch.optim import Adam` elsewhere - confirm.
dcaps = DigitCaps()
optimizer = Adam(dcaps.parameters(),lr=0.001)
```
[和Keras build里面的self.add_weight是一样的](https://keras.io/zh/layers/writing-your-own-keras-layers/)
## autograd
[PyTorch 74.自定义操作torch.autograd.Function - 讲的很好](https://zhuanlan.zhihu.com/p/344802526)
## pytorch_embedding设置不可导
```python
# Replace the layer's weight with a pre-computed NumPy matrix converted to a
# float tensor; requires_grad=False keeps it frozen (no gradient is computed).
# NOTE(review): self.encoder is presumably an nn.Embedding - confirm.
self.encoder.weight = nn.Parameter(t.from_numpy(embedding_matrix).float(), requires_grad=False)
```
## 中文tokenizer
```python
import six
import unicodedata
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Args:
        text: a byte string or a text (unicode) string.

    Returns:
        The text string. Byte strings are decoded as UTF-8 and undecodable
        bytes are dropped.

    Raises:
        ValueError: if `text` is neither a byte string nor a text string.
    """
    # `bytes` is the byte-string type on Python 3 and an alias of `str` on
    # Python 2; `type(u"")` is `str` on Python 3 and `unicode` on Python 2.
    # Checking against these two covers both interpreters without a version
    # switch and without naming the Python-2-only `unicode` builtin.
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    if isinstance(text, type(u"")):
        return text
    raise ValueError("Unsupported string type: %s" % (type(text)))
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
co
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。关于机器学习的项目,新手可作为入门项目学习,欢迎下载使用。
资源推荐
资源详情
资源评论
收起资源包目录
模拟神经元功能和网络结构,来完成认知任务的一类机器学习算法.zip (332个子文件)
apiaccept 20KB
bwordcom 681B
detail-3c56c08d12.min.css 101KB
markdown_views-e44c3c0e64.css 69KB
content_toolbar.css 25KB
sandalstrap.min.css 16KB
ck_htmledit_views-3019150162.css 8KB
blog_code-c3a0c33d5c.css 5KB
pub_footer_1.0.3.css 4KB
share_style0_16.css 4KB
skin-yellow-2eefd34acf.min.css 3KB
side-toolbar.css 2KB
atom-one-light.css 1KB
paging-e040f0c7c8.css 1KB
chart-3456820cac.css 293B
20190406093237535.gif 173KB
feedLoading.gif 3KB
.gitignore 36B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
pycuda-2020.1.tar.gz 1.57MB
hostfile 39B
001-TensorFlow 2.0 教程-Transformer - 知行_那片天 - CSDN博客.html 1.02MB
saved_resource.html 149B
2.object_detection_with_model_zoo.ipynb 2.33MB
002-DCGAN.ipynb 604KB
001-Transformer.ipynb 103KB
mojing_dssm.ipynb 41KB
我的第一个pytorch_cnn分类(11类别).ipynb 31KB
4.load_your_own_mxnet_bert.ipynb 28KB
mojing_lstm.ipynb 27KB
mojing3_conv1d.ipynb 24KB
4.load_your_own_pytorch_bert.ipynb 22KB
pytorch_distributed_textcnn_cpu.ipynb 20KB
1.mnist_demo_java.ipynb 13KB
run_pl.ipynb 13KB
gpu_gan_train.ipynb 13KB
pytorch多个进程模拟集群的分布式部署.ipynb 9KB
3.BERTQA.ipynb 7KB
gan_generate_pic.ipynb 6KB
inference.ipynb 5KB
1.tensorflow_andrewNg.ipynb 2KB
查看模型文件信息.jpg 90KB
0a438fc15ad624acddf1b2538c463330-0.jpg 5KB
common-37b7aadaf4.min.js 600KB
detail-1e5a65cde8.min.js 125KB
jquery-1.9.1.min.js 90KB
MathJax.js 62KB
pc_wap_common-f868939e52.js 56KB
content_toolbar.js 38KB
hm.js 35KB
share.js 17KB
collection-box.js 14KB
baidu_opensug-1.0.0.js 14KB
iconfont.js 13KB
main.js 13KB
publib_footer-1.0.3.js 11KB
paging-3d3b805766.js 8KB
bword.min.js 4KB
notify.js 4KB
side-toolbar.js 4KB
sandalstrap.min.js 3KB
linkCatcher-3a08af3a5f.js 1KB
counter.js 520B
skin-yellow-fc7383b956.min.js 255B
ds_config.json 1KB
ds_config.json 449B
macbert_training_monitor.json 116B
bert_training_monitor.json 116B
README.md 15KB
README.md 14KB
README.md 13KB
cnn_movanzhou.md 5KB
README.md 5KB
README.md 5KB
README.md 2KB
README.md 2KB
README.md 1KB
README.md 984B
README.md 528B
README.md 521B
README.md 434B
README.md 406B
README.md 341B
README.md 338B
README.md 232B
README.md 232B
README.md 232B
README.md 202B
README.md 192B
README.md 47B
README.md 45B
README.md 10B
result.png 1.01MB
crossentropyloss.png 118KB
20190507163127595.png 112KB
多机多卡env.png 75KB
20190507163058667.png 64KB
共 332 条
- 1
- 2
- 3
- 4
资源评论
c++服务器开发
- 粉丝: 3176
- 资源: 4461
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功