Source code for hfai.client.api.training_api
# @Author :hpp
# @CreatedAt :2021/5/7 16:21
import os
import sys
import time
from typing import Optional
import sysv_ipc
import itertools
import pickle
import socket
from hfai.base_model.base_task import BaseTask
from hfai.conf.flags import EXP_PRIORITY, WARN_TYPE
WATCHDOG_TIME_SHM_ID = 237965198
try:
WATCHDOG_TIME_SHM_ID_SHM = sysv_ipc.SharedMemory(WATCHDOG_TIME_SHM_ID, sysv_ipc.IPC_CREX, mode=0o777)
except sysv_ipc.ExistentialError:
WATCHDOG_TIME_SHM_ID_SHM = sysv_ipc.SharedMemory(WATCHDOG_TIME_SHM_ID)
try:
import pynvml
no_pynvml = False
except:
no_pynvml = True
# ================== 以下接口用于训练中调用 ======================================
IMPORT_TIME = time.time()
SIMULATE_SUSPEND_SEC = int(os.environ.get('SIMULATE_SUSPEND', -1))
IS_SIMULATE = os.environ.get('HFAI_SIMULATE', '0') == '1'
def nb_name() -> str:
return os.environ.get("MARSV2_NB_NAME", "NO_CLUSTER")
def task_id() -> str:
return os.environ.get("MARSV2_TASK_ID", -1)
def rank() -> int:
return int(os.environ.get("MARSV2_RANK", 0))
def node_name() -> str:
return os.environ.get('MARSV2_NODE_NAME', 'NAN_NODE')
def user_name() -> str:
return os.environ.get('MARSV2_USER', '')
def current_selector_task() -> BaseTask:
""" 获取调用 API 时用于选定任务的 task 实例, 仅有 nb_name 和 id 属性 """
task = BaseTask()
task.nb_name = nb_name()
task.id = int(task_id())
return task
def send_data(data, timeout: int = 500, raise_exception: bool = True):
"""
把 data 发送给 manager
socket的 backlog 为1024,超过1024的并发请求可能会变得很慢,请注意
Args:
data (dict):
timeout (int): 设置请求超时时间,默认为 500 秒
raise_exception (bool): 调用runtime接口时发生异常是否需要抛出,默认为不抛出
Returns:
bool: 表示是否通信成功
"""
b_data = pickle.dumps(data)
header = str(len(b_data) + 8).rjust(8).encode()
start_time = time.time()
result = None
while True:
if time.time() - start_time > timeout:
if raise_exception:
raise Exception(f'{data.get("source", "")}超时,请检查程序或联系管理员')
else:
print(f'{data.get("source", "")}超时,请检查程序或联系管理员')
return False
try:
waiting_time = max(1, timeout - int(time.time() - start_time))
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect((f'{user_name()}-{task_id()}-manager-0', 7000))
s.settimeout(waiting_time)
s.send(header + b_data)
result = s.recv(1024)
except:
time.sleep(1)
if result is not None:
try:
result = pickle.loads(result)
success = result['success']
msg = result['msg']
except Exception as e:
if raise_exception:
raise Exception(f'解析返回值失败: {e},请联系管理员')
else:
print(f'解析返回值失败: {e},请联系管理员')
return False
if success == 0:
if raise_exception:
raise Exception(msg)
else:
print(msg)
return False
return True
[docs]def set_watchdog_time(seconds: int):
"""
设置任务超时时间,规定时间内无 log 该任务会被认为已失败,默认为 1800 秒
Args:
seconds (int): 超时时间,单位为秒
Examples:
>>> from hfai.client import set_watchdog_time
>>> set_watchdog_time(1800)
"""
if IS_SIMULATE:
print(f'模拟设置 watchdog time {seconds} 成功')
return
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return
assert isinstance(seconds, int), "传入的seconds应该是一个int型的"
assert len(str(seconds)) < 100, "传入的seconds位数应当小于100位"
WATCHDOG_TIME_SHM_ID_SHM.write(str(seconds).ljust(100))
return {
'success': 1,
'msg': f'set_watchdog_time {seconds}秒 设置成功'
}
[docs]def get_whole_life_state() -> Optional[int]:
"""
获取当前 chain_id 的上一个 id 任务留下来的 whole_life_state
Returns:
int: whole_life_state
Examples:
>>> from hfai.client import get_whole_life_state
>>> get_whole_life_state()
"""
if IS_SIMULATE:
return int(os.environ.get('MARSV2_WHOLE_LIFE_STATE', 0))
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return None
return int(os.environ.get('MARSV2_WHOLE_LIFE_STATE', 0))
[docs]def set_whole_life_state(state: int, timeout: int = 500, raise_exception: bool = True) -> bool:
"""
设置 whole_life_state
Args:
state (int): 想要设置的 whole_life_state
timeout (int): 设置请求超时时间,默认为 500 秒
raise_exception (bool): 调用runtime接口时发生异常是否需要抛出,默认为抛出
Examples:
>>> from hfai.client import set_whole_life_state
>>> set_whole_life_state(100)
"""
if IS_SIMULATE:
print(f'模拟设置 whole_life_state 成功,请下次调用的时候使用 --sls={state} 来设置生效')
return True
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return False
if os.environ['MARSV2_WHOLE_LIFE_STATE'] == str(state):
return False
data = {
'source': set_whole_life_state.__name__,
'whole_life_state': state
}
return send_data(data, timeout, raise_exception)
# ============== 任务优雅挂起 ===================================================
try:
import sysv_ipc # 注意,不兼容 windows
has_sys_ipc = True
except:
has_sys_ipc = False
SUSPEND_SHM_ID = 7123378543
if (nb_name() != 'NO_CLUSTER' or IS_SIMULATE) and has_sys_ipc:
SUSPEND_SHM_ID_SHM = sysv_ipc.SharedMemory(SUSPEND_SHM_ID,
sysv_ipc.IPC_CREAT, mode=0o777,
size=1)
else:
SUSPEND_SHM_ID_SHM = None
[docs]def receive_suspend_command(timeout: int = 500, raise_exception: bool = False) -> bool:
"""
获取该任务是否即将被打断
Args:
timeout (int): 设置请求超时时间,默认为 500 秒
raise_exception (bool): 调用runtime接口时发生异常是否需要抛出,默认为不抛出
Returns:
bool: 表示是否即将被打断
Examples:
>>> from hfai.client import receive_suspend_command
>>> receive_suspend_command()
"""
if IS_SIMULATE:
if 0 < SIMULATE_SUSPEND_SEC < (time.time() - IMPORT_TIME):
print('时间到了,触发模拟打断')
return True
return False
if SUSPEND_SHM_ID_SHM is None:
return False
if SUSPEND_SHM_ID_SHM.read() == b'1':
# 向 server 端报告需要知道我要被打断了
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return False
data = {
'source': receive_suspend_command.__name__
}
send_data(data, timeout=timeout, raise_exception=raise_exception)
return True
return False
[docs]def go_suspend(timeout: int = 500, raise_exception: bool = False):
"""
通知 server 该任务可以被打断
Args:
timeout (int): 设置请求超时时间,默认为 500 秒
raise_exception (bool): 调用runtime接口时发生异常是否需要抛出,默认为不抛出
Examples:
>>> from hfai.client import go_suspend
>>> go_suspend()
"""
# 防止 jupyter 容器误跑 go_suspend
if os.environ.get('MARSV2_TASK_TYPE', '') != 'training':
return
# 向 server 端报告需要知道我要被打断了
if IS_SIMULATE:
print('模拟打断成功,将退出进程')
sys.exit(0)
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return
data = {
'source': go_suspend.__name__
}
send_data(data, timeout=timeout, raise_exception=raise_exception)
for i in itertools.count(start=1):
time.sleep(10)
print(f'等了{i * 10}秒还没挂起,继续等待')
[docs]def set_priority(priority: int, timeout: int = 500, raise_exception: bool = False) -> bool:
"""
设置当前任务的优先级,注意如果你没有该优先级的权限可能会导致任务被立刻打断
Args:
priority (int): 设置的任务优先级
timeout (int): 设置请求超时时间,默认为 500 秒
raise_exception (bool): 调用runtime接口时发生异常是否需要抛出,默认为不抛出
Returns:
bool: 是否设置成功
Examples:
>>> from hfai.client import set_priority, EXP_PRIORITY
>>> set_priority(EXP_PRIORITY.LOW)
"""
# 向 server 端报告需要知道我要被打断了
if IS_SIMULATE:
print(f'模拟环境设置优先级 {priority} 成功')
return True
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return False
data = {
'source': set_priority.__name__,
'priority': priority
}
return send_data(data, timeout=timeout, raise_exception=raise_exception)
def disable_warn(warn_type: int, timeout: int = 500, raise_exception: bool = False) -> bool:
"""
静默warning报警
Args:
warn_type (int): 静默的报警类型,可以是WARN_TYPE的复合,0表示不静默任何报警
timeout (int): 设置请求超时时间,默认为 500 秒
raise_exception (bool): 调用runtime接口时发生异常是否需要抛出,默认为不抛出
Returns:
bool: 是否设置成功
Examples:
>>> from hfai.client import disable_warn, WARN_TYPE
>>> disable_warn(WARN_TYPE.LOG | WARN_TYPE.COMPLETED) # 日志超时以及completed超时不报警
"""
if IS_SIMULATE:
print(f'模拟静默warning报警成功')
return True
if nb_name() == 'NO_CLUSTER':
print('非集群环境')
return False
data = {
'source': disable_warn.__name__,
'warn_type': warn_type
}
return send_data(data, timeout=timeout, raise_exception=raise_exception)
def print_gpu_info(pid):
if no_pynvml or (not node_name().endswith('dl')):
return False
pynvml.nvmlInit()
for i in range(pynvml.nvmlDeviceGetCount()):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
print(f'[{pid}] gpu[{i}] memory total {info.total}, free {info.free}, used {info.used}; '
f'Power {pynvml.nvmlDeviceGetPowerState(handle)}')