此 PR 源自https://github.com/pytorch/pytorch/pull/13070。
概括:此 PR 的目的是重构 ATen 中的随机数生成器 (RNG) 设计。目前,PyTorch 中的 RNG 具有不对称设计,即 CPU 生成器使用 ATen 类,而 CUDA 生成器使用旧版 THC 代码(THCRNGState, THCState, THCRandom_Init
等)。此外,从目前的设计来看,ATen 中的发电机概念并不明确。此 PR 是围绕 RNG 的更多重构工作的第一部分,目前仅处理 PyTorch 前端和 CPU 后端。它执行以下操作:
- 通过回顾 Generator、CPUGenerator 和 CUDAGenerator 类来阐明生成器概念。
- 将 mt19937 从 TH 移动到 aten 作为 MT19937RNGEngine.h,并将分布从 THRandom.cpp 移动到 DistributionsHelper.h。添加 PhiloxRNGEngine.h 引擎并为其添加单元测试。
- 修复了几个用于代码生成的 python 文件中硬编码的生成器相关代码,例如
function_wrapper.py
等。 - 修复了生成器前端 python 绑定以包括设备 kwarg 和默认 kwarg
- 从类型中删除生成器的创建。
- 更新文档和注释并添加
torch.Generator
api 文档。