# @Author :hpp
# @CreatedAt :2020/12/22 17:46
# 萤火2号 api
import os
from abc import ABC
from io import StringIO
from typing import Tuple, List, Union
import datetime
import munch
from hfai.base_model.base_task import BasePod
from hfai.base_model.training_task import TrainingTask, ITrainingTaskImpl
# client
from hfai.conf.flags import STATUS_COLOR_MAP, EXP_PRIORITY, TASK_TYPE
from rich import box
from rich.table import Table
import urllib
from .api_config import get_mars_token as mars_token
from .api_config import get_mars_url as mars_url
from .api_utils import async_requests, RequestMethod
from .training_api import task_id
# ==============================================================================
class Experiment(TrainingTask):
    """
    Client-side view of a task.

    Attributes:

        - id (int): task id
        - nb_name (str): task name
        - user_name (str): user name
        - code_file (str): path of the training code
        - workspace (str): workspace of the training code
        - config_json (dict): task configuration, including priority (`int`),
          environment (`dict[str, str]`) and whole_life_state (`int`)
        - group (str): group the task belongs to
        - nodes (int): number of nodes the task occupies
        - assigned_nodes (list[str]): nodes assigned to the task
        - whole_life_state (int): the currently set whole_life_state
        - star (bool): whether the task is starred
        - first_id (int): smallest id within the whole chain_id
        - backend (str): environment the task runs in
        - task_type (str): task type
        - queue_status (str): current running status of the task
        - priority (int): current priority of the task
        - chain_id (str): chain_id of the task
        - stop_code (int): how the task exited
        - worker_status (str): status of the task when it finished
        - begin_at (str): task start time
        - end_at (str): task end time
        - created_at (str): task creation time
        - id_list (list[int]): every id in the chain_id
        - begin_at_list (list[str]): start time of every id in the chain_id
        - end_at_list (list[str]): end time of every id in the chain_id
        - stop_code_list (list[int]): exit status of every id in the chain_id
        - whole_life_state_list (list[int]): latest whole_life_state of every id in the chain_id
        - _pods_ (list[Pod]): parameters of each pod of the task

    Examples:

    .. code-block:: python

        from hfai.client import get_experiment
        import asyncio
        experiment: Experiment = asyncio.run(get_experiment(id=1))
        log = asyncio.run(experiment.log_ng(rank=0))  # fetch the log of rank 0
        asyncio.run(experiment.stop())  # stop the task

    """

    # Attributes rendered, in this order, as the columns of the experiment table.
    experiment_columns = [
        'id', 'nb_name', 'nodes', 'chain_status', 'task_type', 'suspend_count', 'created_at'
    ]

    def __init__(self, implement_cls=None, **kwargs):
        """Build the task from keyword fields and wrap each raw pod in a BasePod."""
        super().__init__(implement_cls, **kwargs)
        # suspend_count mirrors restart_count under the name used by experiment_columns.
        self.suspend_count = self.restart_count
        self._pods_ = [BasePod(**raw_pod) for raw_pod in self._pods_]
        self.last_seen = None

    async def set_priority(self, priority: int, **kwargs):
        """Update the task priority through ``self.update`` and return its result."""
        fields = ('priority', )
        values = (priority, )
        return await self.update(fields, values)

    async def rerun(self, *args, **kwargs):
        """
        Ask the server to rerun this task's chain.

        @return: a new Experiment built from the ``task`` entry of the response.
        """
        token = kwargs.get('token', mars_token())
        endpoint = f'{mars_url()}/operating/rerun_task'
        response = await async_requests(
            RequestMethod.POST, f'{endpoint}?token={token}&chain_id={self.chain_id}', [1, 2])
        return Experiment(ExperimentImpl, **response['task'])

    def row(self):
        """Return one rich-markup cell per entry of ``experiment_columns``."""

        def render(column):
            # Colour is looked up by the value itself; unknown values are white.
            value = self.__getattribute__(column)
            tint = STATUS_COLOR_MAP.get(value, 'white')
            return f'[{tint}]{value}[/{tint}]'

        return [render(column) for column in self.experiment_columns]

    def tables(self):
        """
        Build the rich tables describing this task.

        @return: (experiment_table, job_table); the first has a single row for
                 the task, the second has one row per pod.
        """
        experiment_table = Table(show_header=True, box=box.ASCII_DOUBLE_HEAD)
        for column in self.experiment_columns:
            experiment_table.add_column(column)
        experiment_table.add_row(*self.row())

        job_columns = ['rank', 'status', 'node', 'begin_at']
        job_table = Table(show_header=True, box=box.ASCII_DOUBLE_HEAD)
        for column in job_columns:
            job_table.add_column(column)

        for pod in self.pods:
            # Pods may still be plain dicts; normalise them before reading attributes.
            if isinstance(pod, dict):
                pod = BasePod(**pod)
            pod.rank = pod.job_id
            cells = []
            for column in job_columns:
                value = pod.__getattribute__(column)
                tint = STATUS_COLOR_MAP.get(value, 'white')
                cells.append(f'[{tint}]{value}[/{tint}]')
            job_table.add_row(*cells)

        return experiment_table, job_table

    def get_running_time_in_seconds(self) -> float:
        """
        Total running time of the task, including earlier tasks of the same chain.

        @return: running time in seconds.
        """
        if 0 in (len(self.begin_at_list), len(self.end_at_list)):
            raise Exception('当前实例仅包含单个子任务的信息,'
                            '请通过 `hfai.client.api.experiment_api.get_experiment` 等查询接口获取包含完整 chain tasks 信息的实例.')

        def to_datetime(text):
            return datetime.datetime.strptime(text, '%Y-%m-%dT%H:%M:%S.%f')

        started = [to_datetime(text) for text in self.begin_at_list]
        finished = [to_datetime(text) for text in self.end_at_list]
        if self.queue_status == 'scheduled':
            # For a task in the 'scheduled' state the last interval is measured up to now.
            finished[-1] = datetime.datetime.now()

        elapsed = sum(
            (end_at - begin_at for begin_at, end_at in zip(started, finished)),
            datetime.timedelta(0),
        )
        return elapsed.total_seconds()
class ExperimentImpl(ITrainingTaskImpl, ABC):
    """
    Client-side task implementation: every operation is an HTTP POST to the
    server at ``mars_url()``, addressed by the chain_id of ``self.task``.

    Every method accepts an optional ``token`` keyword argument; when it is not
    given, the value returned by ``mars_token()`` is used.
    """

    def select_pods(self, *args, **kwargs):
        # Pods are expected to come straight from the server, so nothing to do here.
        pass

    async def update(self, fields: Tuple[str], values: Tuple, *args, **kwargs):
        """
        Update the priority of the task.

        @param fields: names of the fields to update; the first one must be 'priority'.
        @param values: new values; only ``values[0]`` is sent.
        @return: True once the request has completed.
        """
        assert fields[0] == 'priority', 'client 只能更新 priority'
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/priority/update?token={token}&chain_id={task.chain_id}&priority={values[0]}'
        await async_requests(RequestMethod.POST, url, [1, 2])
        return True

    async def stop(self, op='stop', *args, **kwargs):
        """
        Stop the task.

        @param op: operation name forwarded to the server as-is.
        @return: True once the request has completed.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/stop?token={token}&chain_id={task.chain_id}&op={op}'
        await async_requests(RequestMethod.POST, url, [1, 2])
        return True

    async def suspend(self, restart_delay: int = 0, *args, **kwargs):
        """
        Suspend the task.

        @param restart_delay: forwarded to the server as-is
                              (NOTE(review): presumably a delay in seconds — confirm).
        @return: True once the request has completed.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/suspend?token={token}&chain_id={task.chain_id}&restart_delay={restart_delay}'
        await async_requests(RequestMethod.POST, url, [1])
        return True

    async def log(self, rank: int = 0, last_seen: str = 'null', with_code=False, *args, **kwargs):
        """
        Fetch the log of one rank. The response's ``last_seen`` is stored on
        ``self.task.last_seen`` so the task is refreshed along with the log.

        @param rank: rank whose log is requested.
        @param last_seen: forwarded to the server as-is; 'null' by default.
        @param with_code: when True, return the tuple (log, exit_code, stop_code).
        @return: the log data, or the tuple described above.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/query/task/log?token={token}&chain_id={task.chain_id}&rank={rank}&last_seen={last_seen}'
        res = await async_requests(RequestMethod.POST, url, [1])
        self.task.last_seen = res['last_seen']
        if with_code:
            return res['data'], res["exit_code"], res["stop_code"]
        else:
            return res['data']

    async def log_ng(self, rank: int = 0, last_seen: str = 'null', *args, **kwargs):
        """
        Fetch the log of one rank and return the whole response. The response's
        ``last_seen`` is stored on ``self.task.last_seen``.

        @param rank: rank whose log is requested.
        @param last_seen: forwarded to the server as-is; 'null' by default.
        @return: the full response of the log query.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/query/task/log?token={token}&chain_id={task.chain_id}&rank={rank}&last_seen={last_seen}'
        res = await async_requests(RequestMethod.POST, url, [1])
        self.task.last_seen = res['last_seen']
        return res

    async def sys_log(self, *args, **kwargs):
        """
        Fetch the system error log of the task.

        :return: the ``data`` entry of the response.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/query/task/sys_log?token={token}&chain_id={task.chain_id}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['data']

    async def search_in_global(self, content, *args, **kwargs):
        """
        Count, for every rank of the task, how many times ``content`` occurs in its log.

        :param content: text to search for; it is URL-encoded before being sent.
        :param args:
        :param kwargs:
        :return: a list holding the number of occurrences for each rank.
        """
        # A bare `import urllib` does not guarantee that the `parse` submodule is
        # loaded, so import it explicitly where it is used.
        import urllib.parse
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/query/task/log/search?token={token}&chain_id={task.chain_id}&{urllib.parse.urlencode({"content": content})}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['data']

    async def tag_task(self, tag: str, *args, **kwargs):
        """
        Add a tag to the task.

        :param tag: tag to add.
        :param args:
        :param kwargs:
        :return: the ``msg`` entry of the response.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/tag?token={token}&chain_id={task.chain_id}&tag={tag}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['msg']

    async def untag_task(self, tag: str, *args, **kwargs):
        """
        Remove a tag from the task.

        :param tag: tag to remove.
        :param args:
        :param kwargs:
        :return: the ``msg`` entry of the response.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/untag?token={token}&chain_id={task.chain_id}&tag={tag}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['msg']

    async def map_task_artifact(self, artifact_name: str, artifact_version: str, direction: str, *args, **kwargs):
        """
        Set the artifact information of the task.

        :param artifact_name: artifact name; must not be empty.
        :param artifact_version: artifact version.
        :param direction: forwarded to the server as-is
                          (NOTE(review): presumably tells an input artifact from an
                          output artifact — confirm the accepted values).
        :return: the ``msg`` entry of the response.
        """
        assert artifact_name != '', 'artifact_name 不能为空'
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/artifact/map?token={token}&chain_id={task.chain_id}&artifact_name={artifact_name}&artifact_version={artifact_version}&direction={direction}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['msg']

    async def unmap_task_artifact(self, direction: str, *args, **kwargs):
        """
        Remove the artifact information of the task.

        :param direction: forwarded to the server as-is.
        :return: the ``msg`` entry of the response.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/operating/task/artifact/unmap?token={token}&chain_id={task.chain_id}&direction={direction}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['msg']

    async def get_task_artifact(self, *args, **kwargs):
        """
        Get the artifact information of the task.

        :return: the ``msg`` entry of the response.
        """
        task = self.task
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/query/task/artifact/get?token={token}&chain_id={task.chain_id}'
        res = await async_requests(RequestMethod.POST, url, [1])
        return res['msg']

    async def get_latest_point(self, *args, **kwargs):  # get_experiment_perf_current
        """
        Get the current performance monitoring data of the task.

        :return: the ``data`` entry of the response.
        """
        task = self.task
        # mars_token must be called: passing the function object itself would put
        # its repr into the URL instead of a token.
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/monitor_v2/task_perf_api?token={token}&chain_id={task.chain_id}'
        result = await async_requests(RequestMethod.POST, url)
        return result['data']

    async def get_chain_time_series(self, query_type: str, rank: int = None, *args, **kwargs):
        """
        Get the time-series performance data of the whole chain.

        :param query_type: one of 'gpu', 'cpu', 'mem', 'every_card', 'every_card_mem'.
        :param rank: forwarded to the server as-is (the literal ``None`` when not given).
        :param kwargs: ``data_interval`` (default '5min') is forwarded to the server.
        :return: the ``data`` entry of the response.
        """
        task = self.task
        data_interval = kwargs.get('data_interval', '5min')
        assert query_type in ('gpu', 'cpu', 'mem', 'every_card', 'every_card_mem')
        # mars_token must be called: passing the function object itself would put
        # its repr into the URL instead of a token.
        token = kwargs.get('token', mars_token())
        url = f'{mars_url()}/monitor/task/chain_perf_series?token={token}&chain_id={task.chain_id}&typ={query_type}&rank={rank}&data_interval={data_interval}'
        result = await async_requests(RequestMethod.POST, url)
        return result['data']
# ==============================================================================
async def get_experiments(
        page: int,
        page_size: int,
        only_star=False,
        select_pods=True,
        nb_name_pattern=None,
        task_type_list=None,
        worker_status_list=None,
        queue_status_list=None,
        tag_list=None,
        **kwargs
):
    """
    Get the tasks you submitted most recently.

    Args:
        page (int): page number
        page_size (int): number of tasks per page
        only_star (bool): only consider ``star`` tasks (``False`` by default)
        select_pods (bool): whether to query pods
        nb_name_pattern (str): only tasks whose nb_name contains this string
        task_type_list (list[str]): task_type values to query; defaults to
            ``['training', 'virtual', 'background']`` when ``None``
        worker_status_list (list[str]): worker_status values to query; no filter when ``None``
        queue_status_list (list[str]): queue_status values to query; no filter when ``None``
        tag_list (list[str]): tags to query; no filter when ``None``. The list is never modified.

    Returns:
        int, list[Experiment]: total number of matching tasks, and the tasks of the requested page

    Examples:

        >>> from hfai.client import get_experiments
        >>> import asyncio
        >>> asyncio.run(get_experiments(page=1, page_size=10))  # asyncio.run may be unavailable before python3.8; use another way to run the coroutine there

    """
    if task_type_list is None:
        task_type_list = ['training', 'virtual', 'background']
    if worker_status_list is None:
        worker_status_list = []
    if queue_status_list is None:
        queue_status_list = []
    # Work on a private copy: appending to the argument would leak the star
    # filter into the caller's list and into every later call.
    tags = [] if tag_list is None else list(tag_list)
    if only_star:
        # NOTE(review): the tag sent is 'start' although the flag is only_star —
        # confirm the tag name the server expects.
        tags.append('start')

    token = kwargs.get('token', mars_token())
    url = f'{mars_url()}/query/task/list?page={page}&page_size={page_size}&token={token}'
    for task_type in task_type_list:
        url += f'&task_type={task_type}'
    if nb_name_pattern is not None:
        url += f'&nb_name_pattern={nb_name_pattern}'
    for worker_status in worker_status_list:
        url += f'&worker_status={worker_status}'
    for queue_status in queue_status_list:
        url += f'&queue_status={queue_status}'
    for tag in tags:
        url += f'&tag={tag}'
    url += f'&select_pods={select_pods}'

    result = (await async_requests(RequestMethod.POST, url))['result']
    total = result['total']
    tasks = result['tasks']
    return total, [Experiment(ExperimentImpl, **t) for t in tasks]
async def get_experiment(name: str = None, id: int = None, chain_id: str = None, **kwargs):
    """
    Get a training task by name, id or chain_id; only your own tasks can be fetched.

    When none of the three is given, the id returned by ``task_id()`` is used,
    and it must not be ``None``. When several are given, id takes precedence
    over name, and name over chain_id.

    Args:
        name (str): task name
        id (int): task id
        chain_id (str): task chain_id

    Returns:
        Experiment: the task that was found

    Examples:

        >>> from hfai.client import get_experiment
        >>> import asyncio
        >>> asyncio.run(get_experiment(id=1))  # asyncio.run may be unavailable before python3.8; use another way to run the coroutine there

    """
    if all(selector is None for selector in (name, id, chain_id)):
        id = task_id()
        assert id is not None, '非集群环境必须设置一个 nb_name/id/chain_id'

    token = kwargs.get('token', mars_token())
    base = f'{mars_url()}/query/task?token={token}&'
    if id is not None:
        query = f'id={id}'
    elif name is not None:
        query = f'nb_name={name}'
    else:
        # Only chain_id is left at this point.
        query = f'chain_id={chain_id}'

    response = await async_requests(RequestMethod.POST, base + query)
    return Experiment(ExperimentImpl, **response['result']['task'])
async def create_experiment(config: Union[str, StringIO, munch.Munch], **kwargs) -> Experiment:
    """
    Create a task from a v2 configuration.

    Example configuration:

    .. code-block:: yaml

        version: 2
        name: test_create_experiment
        priority: 20  # optional; internal users 50 40 30 20, external users 0, -1 when omitted
        spec:  # task definition; the cluster runs the following according to it
          # cd /xxx/xxx; YOUR_ENV_KEY=YOUR_ENV_KEY python xxx.py --config config
          workspace: /xxx/xxx  # required
          entrypoint: xxx.py  # required; if entrypoint_binary is False or omitted, .py and .sh are supported, and .sh is run with bash xxx.sh;
                              # if entrypoint_binary is True, entrypoint is treated as an executable and run directly as <entrypoint>
          parameters: --config config  # optional
          environments:  # optional
            YOUR_ENV_KEY: YOUR_ENV_VALUE
          entrypoint_executable: False  # optional; False when omitted; if True, entrypoint is treated as an executable
                                        # NOTE(review): the comment above says entrypoint_binary while this key is entrypoint_executable — confirm which one the server reads
        resource:
          image: registry.high-flyer.cn/hfai/docker_ubuntu2004:20220630.2  # optional; `default` when omitted; an image uploaded through hfai, or a template built into the cluster
          group: jd_a100#heavy  # optional; jd_a100, jd_a100#heavy, jd_a100#light, jd_a100#A, jd_a100#B
          node_count: 1  # required
        services:  # optional, custom services
          - name: custom
            port: 8123
            type: http
            rewrite_uri: true
        options:  # optional
          whole_life_state: 1  # hfai.get_whole_life_state() => 1
          mount_code: 2  # use 3fs prod mount
          py_venv: 202111  # the python environment is sourced before the script runs; hf_env or hfai_env is chosen from the value.
          # Two kinds: 1. 202111 => source haienv 202111; 2.1 hfai_env_name[hfai_env_owner] => source haienv hfai_env_name -u hfai_env_owner
          #            2.2 hfai_env_name => source haienv hfai_env_name
          # hf_env choices: 202105, 202111, 202207; 202111 picks py3.6 or py3.8 depending on the image
          override_node_resource:  # override the default resource options
            cpu: 0
            memory: 0

    Args:
        config (str, StringIO, munch.Munch): path of a configuration file, a YAML string, or a Munch.
            A str is treated as a path when that path exists (after ``~`` expansion),
            otherwise as YAML text.

    Returns:
        Experiment: the created task

    Raises:
        AssertionError: if ``config`` has an unsupported type or fails validation.

    Examples:

    .. code-block:: python

        from hfai.client import create_experiment
        import asyncio
        asyncio.run(create_experiment('config/path'))  # asyncio.run may be unavailable before python3.8; use another way to run the coroutine there

        await create_experiment('''
            version: 2
            name: test_create_experiment
            priority: 20
            ... yaml file
        ''')

    """
    if isinstance(config, str):
        config_file = os.path.expanduser(config)
        if os.path.exists(config_file):
            # Close the file as soon as it has been parsed instead of leaking the handle.
            with open(config_file) as stream:
                config = munch.Munch.fromYAML(stream)
        else:
            config = munch.Munch.fromYAML(StringIO(config))
    elif isinstance(config, StringIO):
        config = munch.Munch.fromYAML(config)
    elif not isinstance(config, munch.Munch):
        # Raised explicitly (rather than `assert 0`) so the check survives `python -O`.
        raise AssertionError('非法输入')

    # Validation: only the simplest checks are done here.
    def check_exist(key):
        value = config.copy()
        ks = key.split('.')
        for _k in ks[:-1]:
            value = value.get(_k, {})
        assert value.get(ks[-1], None) is not None, f'配置项 {key} 必须存在'

    keys = ['version', 'name']
    for k in keys:
        check_exist(k)
    assert int(config.version) == 2, '版本出错,v2 接口需要 v2 配置文件'
    assert len(config.name) <= 511, 'name 长度不应超过 511'

    # Check the profile options.
    profile = config.get('options', {}).get('profile')
    if profile:
        # The log directory is always overwritten with a per-submission, timestamped path.
        now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        profile['log_dir'] = '${MARSV2_LOG_DIR}/haiprof/' + now
        seconds = profile.get('time', 0)
        assert isinstance(seconds, int) and seconds >= 0, "profile 的时间必须是 >= 0 的整数"
        for k in profile.get('interval', {}):
            interval = profile['interval'][k]
            assert isinstance(interval, int) and interval >= 1, "采样周期必须是 >= 1 的整数"

    token = kwargs.get('token', mars_token())
    result = await async_requests(RequestMethod.POST, url=f'{mars_url()}/operating/task/create?token={token}',
                                  assert_success=[1, 2], json=config.__dict__)
    return Experiment(ExperimentImpl, **result['task'])
async def _post_validate(url):
    """
    POST a validation request and condense the response.

    @param url: complete request URL.
    @return: a dict with ``success`` (1 when the response's ``created`` is truthy,
             otherwise 0) and ``msg``; on success the id of the created task is
             appended to the message.
    """
    response = await async_requests(RequestMethod.POST, url)
    if not response['created']:
        return {
            'success': 0,
            'msg': response['msg']
        }
    return {
        'success': 1,
        'msg': f"{response['msg']},任务编号: {response['task']['id']}"
    }
async def validate_nodes(nodes: List[str], file: str = '/marsv2/scripts/validation/validate.sh', backend: str = 'cuda_11', **kwargs):
    """
    Request a validation run on the given nodes.

    @param nodes: node names; sent to the server joined by commas.
    @param file: validation script path, forwarded as-is.
    @param backend: backend name, forwarded as-is.
    @return: the dict produced by ``_post_validate``.
    """
    token = kwargs.get('token', mars_token())
    endpoint = f'{mars_url()}/operating/node/validate'
    query = f'token={token}&file={file}&backend={backend}&nodes={",".join(nodes)}'
    return await _post_validate(f'{endpoint}?{query}')
async def validate_experiment(name: str = None, id: int = None, chain_id: str = None, ranks: Tuple = (), file: str = '/marsv2/scripts/validation/validate.sh', backend: str = 'cuda_11', **kwargs):
    """
    Request a validation run on the ranks of a task.

    The task is selected by id, name or chain_id; at least one must be given,
    and id takes precedence over name, and name over chain_id.

    @param ranks: ranks to validate; an empty value, or one containing 'all',
                  selects every rank. Otherwise each entry must be an int or a
                  numeric string.
    @param file: validation script path, forwarded as-is.
    @param backend: backend name, forwarded as-is.
    @return: the dict produced by ``_post_validate``.
    """
    assert name is not None or id is not None or chain_id is not None, '必须设置一个 nb_name/id/chain_id'

    every_rank = any(rank == 'all' for rank in ranks) or not ranks
    if every_rank:
        chosen_ranks = 'all'
    else:
        for rank in ranks:
            assert isinstance(rank, int) or rank.isnumeric(), '输入的rank必须是个整数'
        chosen_ranks = ','.join(map(str, ranks))

    token = kwargs.get('token', mars_token())
    base = f'{mars_url()}/operating/task/validate?chosen_ranks={chosen_ranks}&token={token}&file={file}&backend={backend}&'
    if id is not None:
        selector = f'id={id}'
    elif name is not None:
        selector = f'nb_name={name}'
    else:
        # Only chain_id is left at this point.
        selector = f'chain_id={chain_id}'
    return await _post_validate(base + selector)
async def get_task_container_log(id, rank, **kwargs):
    """
    Fetch the container log of one rank of a task.

    @param id: task id.
    @param rank: rank whose container log is requested.
    @return: the ``data`` entry found under ``result`` in the response.
    """
    token = kwargs.get('token', mars_token())
    endpoint = f'{mars_url()}/query/task/container_log'
    response = await async_requests(
        RequestMethod.POST,
        url=f'{endpoint}?id={id}&token={token}&rank={rank}',
        assert_success=[1],
    )
    payload = response['result']
    return payload['data']