/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <string>
#include <functional>
#include <nlohmann/json.hpp>
#include "atb/infer_op_params.h"
#include "atb/train_op_params.h"
#include "atb/operation.h"
#include "operation_factory.h"
// Callable signature shared by the *OperationCreate factories in this file: take a JSON
// parameter object and return an atb::Operation pointer. Presumably used to register the
// creators in a name -> factory table; that registration is not visible in this chunk.
using CreateOperationFuncPtr = std::function<atb::Operation *(const nlohmann::json &)>;
/**
 * Build an atb::infer::ActivationParam from JSON and create the matching operation.
 *
 * Every key is optional; an absent key leaves the default-constructed field untouched:
 *   "activationType" (int32) -> param.activationType (cast to atb::infer::ActivationType)
 *   "scale"          (float) -> param.scale
 *   "dim"            (int32) -> param.dim
 *   "geluMode"       (int)   -> param.geluMode (cast to ActivationParam::GeLUMode)
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::type_error if a present key holds an incompatible JSON type.
 */
static atb::Operation *ActivationOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::ActivationParam param;
    if (paramJson.contains("activationType")) {
        param.activationType = atb::infer::ActivationType(paramJson["activationType"].get<int32_t>());
    }
    if (paramJson.contains("scale")) {
        param.scale = paramJson["scale"].get<float>();
    }
    if (paramJson.contains("dim")) {
        param.dim = paramJson["dim"].get<int32_t>();
    }
    if (paramJson.contains("geluMode")) {
        param.geluMode = atb::infer::ActivationParam::GeLUMode(paramJson["geluMode"].get<int>());
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::AllGatherParam from JSON and create the matching operation.
 *
 * Required keys (missing -> nlohmann::json::out_of_range):
 *   "rank" (int), "rankSize" (int)
 * Optional keys (absent -> default-constructed field kept):
 *   "rankRoot" (int), "backend" (string), "commMode" (int, cast to atb::infer::CommMode),
 *   "rankTableFile" (string), "commDomain" (string)
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::out_of_range if a required key is missing.
 * @throws nlohmann::json::type_error if a present key holds an incompatible JSON type.
 */
static atb::Operation *AllGatherOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::AllGatherParam param;
    // .at() throws on a missing key; operator[] on a const json is undefined behaviour instead.
    param.rank = paramJson.at("rank").get<int>();
    param.rankSize = paramJson.at("rankSize").get<int>();
    if (paramJson.contains("rankRoot")) {
        param.rankRoot = paramJson["rankRoot"].get<int>();
    }
    if (paramJson.contains("backend")) {
        param.backend = paramJson["backend"].get<std::string>();
    }
    if (paramJson.contains("commMode")) {
        param.commMode = atb::infer::CommMode(paramJson["commMode"].get<int>());
    }
    if (paramJson.contains("rankTableFile")) {
        param.rankTableFile = paramJson["rankTableFile"].get<std::string>();
    }
    if (paramJson.contains("commDomain")) {
        param.commDomain = paramJson["commDomain"].get<std::string>();
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::AllReduceParam from JSON and create the matching operation.
 *
 * Required keys (missing -> nlohmann::json::out_of_range):
 *   "rank" (int), "rankSize" (int)
 * Optional keys (absent -> default-constructed field kept):
 *   "rankRoot" (int), "backend" (string), "allReduceType" (string),
 *   "commMode" (int, cast to atb::infer::CommMode), "rankTableFile" (string), "commDomain" (string)
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::out_of_range if a required key is missing.
 * @throws nlohmann::json::type_error if a present key holds an incompatible JSON type.
 */
static atb::Operation *AllReduceOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::AllReduceParam param;
    // .at() throws on a missing key; operator[] on a const json is undefined behaviour instead.
    param.rank = paramJson.at("rank").get<int>();
    param.rankSize = paramJson.at("rankSize").get<int>();
    if (paramJson.contains("rankRoot")) {
        param.rankRoot = paramJson["rankRoot"].get<int>();
    }
    if (paramJson.contains("backend")) {
        param.backend = paramJson["backend"].get<std::string>();
    }
    if (paramJson.contains("allReduceType")) {
        param.allReduceType = paramJson["allReduceType"].get<std::string>();
    }
    if (paramJson.contains("commMode")) {
        param.commMode = atb::infer::CommMode(paramJson["commMode"].get<int>());
    }
    if (paramJson.contains("rankTableFile")) {
        param.rankTableFile = paramJson["rankTableFile"].get<std::string>();
    }
    if (paramJson.contains("commDomain")) {
        param.commDomain = paramJson["commDomain"].get<std::string>();
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::AsStridedParam from JSON and create the matching operation.
 *
 * Required keys, each iterated element-wise and appended as int64:
 *   "size"   -> param.size
 *   "stride" -> param.stride
 *   "offset" -> param.offset
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::out_of_range if a required key is missing.
 * @throws nlohmann::json::type_error if an element is not convertible to int64.
 */
static atb::Operation *AsStridedOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::AsStridedParam param;
    // .at() throws on a missing key; operator[] on a const json is undefined behaviour instead.
    // Bind elements by const reference to avoid copying each json value.
    for (const auto &item : paramJson.at("size")) {
        param.size.push_back(item.get<int64_t>());
    }
    for (const auto &item : paramJson.at("stride")) {
        param.stride.push_back(item.get<int64_t>());
    }
    for (const auto &item : paramJson.at("offset")) {
        param.offset.push_back(item.get<int64_t>());
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::BroadcastParam from JSON and create the matching operation.
 *
 * Required keys (missing -> nlohmann::json::out_of_range):
 *   "rank" (int), "rankSize" (int)
 * Optional keys (absent -> default-constructed field kept):
 *   "rankRoot" (int), "commMode" (int, cast to atb::infer::CommMode), "backend" (string),
 *   "rankTableFile" (string), "commDomain" (string)
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::out_of_range if a required key is missing.
 * @throws nlohmann::json::type_error if a present key holds an incompatible JSON type.
 */
static atb::Operation *BroadcastOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::BroadcastParam param;
    // .at() throws on a missing key; operator[] on a const json is undefined behaviour instead.
    param.rank = paramJson.at("rank").get<int>();
    param.rankSize = paramJson.at("rankSize").get<int>();
    if (paramJson.contains("rankRoot")) {
        param.rankRoot = paramJson["rankRoot"].get<int>();
    }
    if (paramJson.contains("commMode")) {
        param.commMode = atb::infer::CommMode(paramJson["commMode"].get<int>());
    }
    if (paramJson.contains("backend")) {
        param.backend = paramJson["backend"].get<std::string>();
    }
    if (paramJson.contains("rankTableFile")) {
        param.rankTableFile = paramJson["rankTableFile"].get<std::string>();
    }
    if (paramJson.contains("commDomain")) {
        param.commDomain = paramJson["commDomain"].get<std::string>();
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::ConcatParam from JSON and create the matching operation.
 *
 * Optional key (absent -> default-constructed field kept):
 *   "concatDim" (int) -> param.concatDim
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::type_error if "concatDim" is present but not convertible to int.
 */
static atb::Operation *ConcatOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::ConcatParam param;
    if (paramJson.contains("concatDim")) {
        param.concatDim = paramJson["concatDim"].get<int>();
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::CumsumParam from JSON and create the matching operation.
 *
 * Required key:
 *   "axes" -> each element appended to param.axes as int64
 * Optional keys (absent -> default-constructed field kept):
 *   "exclusive" (bool), "reverse" (bool)
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::out_of_range if "axes" is missing.
 * @throws nlohmann::json::type_error if a present value has an incompatible JSON type.
 */
static atb::Operation *CumsumOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::CumsumParam param;
    // .at() throws on a missing key; operator[] on a const json is undefined behaviour instead.
    // Bind elements by const reference to avoid copying each json value.
    for (const auto &item : paramJson.at("axes")) {
        param.axes.push_back(item.get<int64_t>());
    }
    if (paramJson.contains("exclusive")) {
        param.exclusive = paramJson["exclusive"].get<bool>();
    }
    if (paramJson.contains("reverse")) {
        param.reverse = paramJson["reverse"].get<bool>();
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::infer::ElewiseParam from JSON and create the matching operation.
 *
 * Required key (missing -> nlohmann::json::out_of_range):
 *   "elewiseType" -> param.elewiseType (converted by nlohmann's enum conversion)
 * Optional keys (absent -> default-constructed field kept):
 *   "varAttr"       (float)       -> param.mulsParam.varAttr
 *   "outTensorType" (aclDataType) -> param.outTensorType
 *   "inputScale"    (float)       -> param.quantParam.inputScale
 *   "inputOffset"   (int)         -> param.quantParam.inputOffset
 *   "asymmetric"    (bool)        -> param.quantParam.asymmetric
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::out_of_range if "elewiseType" is missing.
 * @throws nlohmann::json::type_error if a present key holds an incompatible JSON type.
 */
static atb::Operation *ElewiseOperationCreate(const nlohmann::json &paramJson)
{
    atb::infer::ElewiseParam param;
    // .at() throws on a missing key; operator[] on a const json is undefined behaviour instead.
    param.elewiseType = paramJson.at("elewiseType").get<atb::infer::ElewiseParam::ElewiseType>();
    if (paramJson.contains("varAttr")) {
        param.mulsParam.varAttr = paramJson["varAttr"].get<float>();
    }
    if (paramJson.contains("outTensorType")) {
        param.outTensorType = paramJson["outTensorType"].get<aclDataType>();
    }
    if (paramJson.contains("inputScale")) {
        param.quantParam.inputScale = paramJson["inputScale"].get<float>();
    }
    if (paramJson.contains("inputOffset")) {
        param.quantParam.inputOffset = paramJson["inputOffset"].get<int>();
    }
    if (paramJson.contains("asymmetric")) {
        param.quantParam.asymmetric = paramJson["asymmetric"].get<bool>();
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::train::FastSoftMaxGradParam from JSON and create the matching operation.
 *
 * Optional keys (absent -> default-constructed field kept):
 *   "headNum" (int32)  -> param.headNum
 *   "qSeqLen" (array)  -> each element appended to param.qSeqLen as int32
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::type_error if a present value has an incompatible JSON type.
 */
static atb::Operation *FastSoftMaxGradOperationCreate(const nlohmann::json &paramJson)
{
    atb::train::FastSoftMaxGradParam param;
    if (paramJson.contains("headNum")) {
        param.headNum = paramJson["headNum"].get<int32_t>();
    }
    if (paramJson.contains("qSeqLen")) {
        // Bind elements by const reference to avoid copying each json value.
        for (const auto &item : paramJson["qSeqLen"]) {
            param.qSeqLen.push_back(item.get<int32_t>());
        }
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
/**
 * Build an atb::train::FastSoftMaxParam from JSON and create the matching operation.
 *
 * Optional keys (absent -> default-constructed field kept):
 *   "headNum" (int32)  -> param.headNum
 *   "qSeqLen" (array)  -> each element appended to param.qSeqLen as int32
 *
 * @param paramJson JSON object holding the parameters.
 * @return The created operation, or nullptr if CreateOperation reports a non-zero status.
 * @throws nlohmann::json::type_error if a present value has an incompatible JSON type.
 */
static atb::Operation *FastSoftMaxOperationCreate(const nlohmann::json &paramJson)
{
    atb::train::FastSoftMaxParam param;
    if (paramJson.contains("headNum")) {
        param.headNum = paramJson["headNum"].get<int32_t>();
    }
    if (paramJson.contains("qSeqLen")) {
        // Bind elements by const reference to avoid copying each json value.
        for (const auto &item : paramJson["qSeqLen"]) {
            param.qSeqLen.push_back(item.get<int32_t>());
        }
    }
    // Initialized so that a failed CreateOperation can never hand back an indeterminate pointer.
    atb::Operation *op = nullptr;
    // NOTE(review): assumes atb::Status uses 0 for success (atb::NO_ERROR) — confirm against atb/types.h.
    if (CreateOperation(param, &op) != 0) {
        return nullptr;
    }
    return op;
}
static atb::Operation *FillOperationCreate(const nlohmann::json ¶mJson)
{
atb::infer::FillParam param;
if (paramJson.contains("withMask")) {
param.withMask = pa
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
本项目为基于Python的AI一体化开发与调试工具设计源码,包含955个文件,涵盖485个Python脚本、119个Markdown文档、42个Java类文件、41个PNG图像文件、39个SVG图形文件、38个Shell脚本文件、37个JavaScript文件、33个C++源文件、21个C头文件、22个文本文件以及21个HTML文件。该工具链提供统一的推理工具入口,支持一站式开发与调试调优,旨在为用户提供高效便捷的AI开发体验。
资源推荐
资源详情
资源评论
收起资源包目录
基于Python的AI一体化开发与调试工具设计源码 (856个子文件)
install.bat 3KB
model.cfg 692B
aipp.config 320B
operation_creator.cpp 34KB
main.cpp 5KB
graph_utils.cpp 3KB
aie_convert.cpp 2KB
export_om_model.cpp 1KB
公网URL使用说明.csv 14KB
Dockerfile 6KB
graph_refactor_demo.gif 249KB
.gitignore 2KB
.gitignore 203B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
.gitkeep 0B
operation_factory.h 1KB
graph_utils.h 944B
config.ini 668B
config.ini 52B
cat.jpg 76KB
dog.jpg 13KB
test.om.json 509KB
model.json 209KB
Bloom7BFlashAttentionModel.json 22KB
opp.json 2KB
model_tree.json 2KB
test.onnx.json 444B
.keep 0B
.keep 0B
.keep 0B
LICENSE 11KB
knowledge_optimizer_rules.md 23KB
API_GUIDE.md 21KB
大模型精度问题定位全流程.md 20KB
llm大模型推理精度工具功能说明_v0.2.1.md 16KB
大模型推理精度工具说明文档.md 16KB
工具-llm模型迁移分析使用说明.md 15KB
graph_refactor_API.md 15KB
工具-精度预检使用说明.md 12KB
README_EN.md 12KB
工具-DUMP加速库数据使用说明.md 12KB
knowledge_optimizer_framework.md 11KB
TorchAir场景-整网算子精度比对.md 9KB
精度预检能力使用说明.md 8KB
basic_usage.md 8KB
自动映射比对能力说明.md 7KB
加速库场景-整网精度比对.md 7KB
FAQ.md 7KB
llm大模型推理精度工具功能说明_v0.2.0.md 6KB
工具-BadCase分析使用说明.md 6KB
工具-DUMP在线推理数据使用说明.md 6KB
graph_refactor_BaseGraph.md 6KB
introduction.md 5KB
如何识别 Bad Case.md 5KB
history.md 5KB
graph_refactor_BaseNode.md 5KB
加速库场景-输出Token的logits精度比对.md 5KB
FAQ.md 5KB
精度比对结果参数说明.md 5KB
工具-手动映射比对能力说明.md 4KB
doc-guidelines.md 4KB
工具-自动比对功能使用说明.md 4KB
llm大模型推理精度工具功能说明_v0.1.0.md 3KB
工具-异常检测使用说明.md 3KB
MindIE-Torch场景-整网算子精度对比.md 3KB
TorchAir场景Dump案例.md 2KB
graph_optimizer.md 2KB
FAQ.md 2KB
README.md 864B
FAQ.md 864B
FAQ.md 864B
OWNERS 736B
torch_topo.png 236KB
acc-workflow.png 218KB
atb_topo.png 186KB
ait_flow.png 172KB
msit-flow.png 104KB
api_quick_start.png 88KB
cmp_report.png 86KB
chatglm6b_cmp_result.png 85KB
7bcb6c78-a839-4dad-bb0a-4abcba10694c.png 81KB
比对报告.PNG 69KB
msit-llm-flow.png 66KB
problem_node.png 66KB
bloom7b_cmp_result.png 65KB
excerpt.png 63KB
basenode.png 62KB
LocationProgress.png 49KB
inference.png 37KB
matched_pie.png 34KB
single_node.png 30KB
acl_pta_workflow.png 30KB
说明.png 29KB
共 856 条
- 1
- 2
- 3
- 4
- 5
- 6
- 9
资源评论
wjs2024
- 粉丝: 2368
- 资源: 5526
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功