注意采样-火炬
这是该论文的PyTorch实施: 。 该存储库基于用TensorFlow编写的本论文的。
移植到PyTorch
原始存储库中的代码已被重写为PyTorch 1.4.0实现。 最困难的部分是重写从高分辨率图像中提取补丁的功能。 原始版本为此使用了特殊的C / C ++文件,我已经在本地Python中完成了此操作。 由于可能需要嵌套的for循环,因此这可能效率更低,速度更慢。 我测试了并行执行补丁提取的过程,但这增加了很多开销,实际上它要慢一些。
此外,我希望我实现了正确计算期望值的部分。 这使用了一个自定义的backward()函数,我希望其中没有错误。
表现
此代码存储库已针对原始文件中提到的两项任务进行了测试:Mega-MNIST和交通标志检测任务。 对结果的定性分析表明它们与原始工作具有可比性,但是定性分析表明此代码库中的错误较高。 几个用户已经警告我,他们无法使
评论0
最新资源