Source code for hfai.pl.utilities.to_hfai
import pytorch_lightning
import torch
from hfai.nn import to_hfai
def nn_to_hfai(lightningmodule: pytorch_lightning.LightningModule):
    """
    Convert the operators of a LightningModule's child modules to `hfai` operators.

    Supports ``1.5.0 <= pytorch_lightning.__version__ <= 1.7.6``.

    Every direct child of ``lightningmodule`` that is a ``torch.nn.Module`` is
    passed through ``hfai.nn.to_hfai``. The module it returns is registered
    back under the same child name. This means the conversion takes effect
    even when ``to_hfai`` returns a different module object rather than
    modifying its argument in place. The names of the converted children are
    printed.

    Args:
        lightningmodule (pytorch_lightning.LightningModule): the
            `pytorch_lightning` model used for training. Its child-module
            registry is modified in place.

    Returns:
        lightningmodule (pytorch_lightning.LightningModule): the same
            `pytorch_lightning` model object, with its operators converted to
            `hfai` operators.

    .. code-block:: python

        from hfai.pl import nn_to_hfai

        trainer = pytorch_lightning.Trainer(
            max_epochs=3,
            gpus=8,
            # hfai supports ddp_bind_numa, ddp_spawn_bind_numa,
            # hfreduce_bind_numa, hfreduce_spawn_bind_numa
            strategy="ddp_bind_numa",
            plugins=[HFAIEnvironment()],  # define the hfai environment and pass it as a plugin
        )
        model_module = nn_to_hfai(ToyNetModule())  # convert operators to hfai operators
        trainer.fit(model_module)
    """
    to_hfai_modules = []
    # Snapshot the items so that replacing entries below never mutates the
    # mapping while it is being iterated.
    for name, child in list(lightningmodule._modules.items()):
        # Skip registry values that are not modules (e.g. a None placeholder).
        if isinstance(child, torch.nn.Module):
            # Assign into the registry dict itself. Calling setattr() on the
            # dict object would only attach an attribute to it and leave the
            # registered child module unchanged.
            lightningmodule._modules[name] = to_hfai(child)
            to_hfai_modules.append(name)
    print(f'To hfai.nn modules: {to_hfai_modules}')
    return lightningmodule