Shortcuts

Source code for hfai.pl.utilities.to_hfai

import pytorch_lightning
import torch
from hfai.nn import to_hfai


def nn_to_hfai(lightningmodule: pytorch_lightning.LightningModule):
    """
    Convert the operators of a LightningModule to `hfai` operators.

    Supports `1.5.0 <= pytorch_lightning.__version__ <= 1.7.6`.

    Every directly registered child that is a `torch.nn.Module` is passed to
    `hfai.nn.to_hfai`, and the module it returns is stored back under the same
    name. The names of the converted children are printed to stdout.

    Args:
        lightningmodule (pytorch_lightning.LightningModule): the
            `pytorch_lightning` model used for training

    Returns:
        lightningmodule (pytorch_lightning.LightningModule): the same
            `pytorch_lightning` model, with its operators converted to `hfai`
            operators

    .. code-block:: python

        from hfai.pl import nn_to_hfai

        trainer = pytorch_lightning.Trainer(
            max_epochs=3,
            gpus=8,
            strategy="ddp_bind_numa",  # hfai supports ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
            plugins=[HFAIEnvironment()]  # define the hfai environment and pass it in as a plugin
        )

        model_module = nn_to_hfai(ToyNetModule())  # convert the operators to hfai operators

        trainer.fit(
            model_module
        )

    """
    to_hfai_modules = []
    # Iterate over a snapshot so that writing entries back cannot interfere
    # with the iteration.
    for name, value in list(lightningmodule._modules.items()):
        if isinstance(value, torch.nn.Module):
            # Assign by key: `_modules` is a mapping, so `setattr` on it would
            # only attach an attribute to the mapping object (or fail on a
            # plain dict) and leave the registered child unchanged.
            lightningmodule._modules[name] = to_hfai(value)
            to_hfai_modules.append(name)
    print(f'To hfai.nn modules: {to_hfai_modules}')
    return lightningmodule