整理了从环境安装、ONNX 导出、OM 模型转换、推理验证到算子支持查询的完整流程,并针对 GridSampler2D/3D 的关键问题和解决思路进行了说明。
1. 参考资料
2. 安装环境
1 | pip install torch>=1.9 onnx==1.12.0 onnxruntime==1.14.0 |
- Opset v16
3. 导出 2D GridSample 的 ONNX 模型
- 编写导出脚本
export_grid_sample_onnx.py
1 | vim export_grid_sample_onnx.py |
1 | import torch |
- 运行导出脚本
1 | python3 export_grid_sample.py |
4. 转换成om模型
- 使用 ATC 工具将 ONNX 模型转换为 OM 模型:
1 | atc --model=grid_sample_model.onnx --framework=5 --output=grid_sample_model --input_shape="input_x:1,1,64,64;input_grid:1,64,64,2" --soc_version=Ascend910B3 |
5. 开始推理验证
- 安装ais_bench推理工具:从tools: Ascend tools - Gitee.com,通过源代码编译安装
1 | cd tools-master/ais-bench_workload/tool/ais_bench |
- 生成输入数据
generate_bin.py
1 | mkdir prep_dataset |
1 | import torch |
运行脚本生成二进制输入文件,并创建相应目录:
1 | python3 generate_bin.py |
- 使用 ais_bench 进行推理
1 | cd .. |
6. 更多性能数据
创建或编辑 acl.json
文件:
1 | vim acl.json |
1 | { |
执行推理并采集性能数据:
1 | python3 -m ais_bench --model ./grid_sample_model.om --acl_json_path ./acl.json |
7. 算子支持与注意事项
- 查看CANN算子规格
如果使用GridSampler3D:
1 | vim export_grid_sample_3D.py |
1 | import torch |
1 | python3 export_grid_sample_3D.py |
GridSampler2D 与 GridSampler3D 的限制
PyTorch ONNX 导出限制:目前
torch.onnx.export
仅支持 2D GridSample(4D 输入),不支持 3D(5D)输入,会报错。Ascend AI Core 支持:
- GridSampler3D 有 AI Core Kernel 实现,需 float32 前向;
- GridSampler2D 目前仅支持在 AICPU 上执行,无法使用 AI Core。
3D grid_sample 的计算量通常远大于 2D,如果本来只是个 2D 任务,用 3D 人工加一维度可能并不能带来真正的功能收益,只是为了能够在 AI Core上执行。
这是一个前向支持 / 后向不支持的状况。如果要在 Ascend AI Core 上做推理,又通过 ONNX/OM 流程,目前没有官方开箱可行的方案。
8. 解决思路:
面对上述限制,以下是可能的应对思路:
使用 2D GridSample:如果不强制使用 AI Core 加速,可使用 2D GridSample,会在 AICPU 上运行,性能较低但流程简单。
直接在 PyTorch + Ascend NPU 上推理:跳过 ONNX/OM 转换,使用
torch_npu
在 Ascend NPU 上直接运行模型。使用 GridSampler3D:
构造 5D 输入
x
(形状[N, C, D, H, W]
,dtype float32)和 5D 网格grid
(形状[N, D, H, W, 3]
,dtype float32)。在计算图中添加 GridSampler3D 节点,并设置相关属性。
使用 ATC 或其他编译器,将计算图编译为
.om
文件,在 AI Core 上执行。注意这种方式需要深入了解 Ascend 图编译流程。
自定义实现:自行实现 3D GridSample 的 ONNX symbolic 函数或 Ascend 自定义算子,但工作量较大。
9. 查询算子支持
- 使用
ms_fast_query
工具查询算子支持情况:
1 | cd /usr/local/Ascend/ascend-toolkit/latest/tools/ms_fast_query |
查询结果会生成在指定的 op.json
文件中,可用于查看算子支持详情。
10. 查看算子原型
1 | cd /usr/local/Ascend/ascend-toolkit/latest/opp/built-in/op_proto/ |