强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧.
文档目录

LLVM 开发指南 / 第 17 章:MLIR 多级 IR

第 17 章:MLIR 多级 IR

“MLIR 不只是一个 IR,而是一个构建 IR 的框架。”


17.1 MLIR 概述

MLIR(Multi-Level Intermediate Representation) 是 LLVM 项目中的通用编译器基础设施,旨在解决传统编译器中的"方言碎片化"问题。

17.1.1 为什么需要 MLIR?

传统编译器的问题:
  每个领域 (深度学习、HPC、量子计算) 都有自己的 IR
  无法共享优化和基础设施
  重复劳动

MLIR 的解决方案:
  提供统一的可扩展框架
  通过 Dialect 支持任意抽象级别
  渐进式降低 (Progressive Lowering)

17.1.2 MLIR 应用场景

场景说明用户
深度学习TensorFlow → MLIR → GPU/TPUTensorFlow, PyTorch
HPC 编译循环优化、向量化Polygeist, Triton
硬件加速器FPGA/ASIC 编译CIRCT, HEDA
领域特定语言DSL 编译器数学、物理引擎
源到源翻译代码变换MLIR-to-MLIR

17.2 核心概念

17.2.1 Dialect(方言)

Dialect 是 MLIR 中组织 Operation 和 Type 的命名空间:

┌──────────────────────────────────────┐
│              MLIR                     │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │ arith    │  │  func    │         │
│  │ 算术运算  │  │ 函数定义  │         │
│  └──────────┘  └──────────┘         │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │  linalg  │  │ tensor   │         │
│  │ 线性代数  │  │ 张量操作  │         │
│  └──────────┘  └──────────┘         │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │   scf    │  │  memref  │         │
│  │ 结构控制流│  │ 内存引用  │         │
│  └──────────┘  └──────────┘         │
│                                      │
│  ┌──────────┐  ┌──────────┐         │
│  │  llvm    │  │  gpu     │         │
│  │ LLVM 方言│  │ GPU 操作  │         │
│  └──────────┘  └──────────┘         │
└──────────────────────────────────────┘

17.2.2 Operation(操作)

// MLIR 操作格式:
%result = dialect.operation_name %operand1, %operand2 {attr = value} : (type1, type2) -> type3

// 示例:
%sum = arith.addi %a, %b : i32                    // 整数加法
%prod = arith.mulf %x, %y : f32                   // 浮点乘法
%cmp = arith.cmpi eq, %a, %b : i32                // 整数比较

// 函数定义
func.func @add(%arg0: i32, %arg1: i32) -> i32 {
  %result = arith.addi %arg0, %arg1 : i32
  return %result : i32
}

// 控制流
scf.for %i = %lb to %ub step %step {
  %val = memref.load %arr[%i] : memref<10xi32>
  scf.yield
}

// 张量操作
%c = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
                   outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>

17.2.3 Block 和 Region

// Region 包含 Block
// Block 包含 Operation
// Operation 可以包含 Region (嵌套)

func.func @example(%cond: i1) -> i32 {
  // 这里是一个 Region
  // 包含一个 Block (entry block)

  cf.cond_br %cond, ^bb1, ^bb2  // 条件分支

^bb1:
  %x = arith.constant 42 : i32
  cf.br ^bb3(%x : i32)

^bb2:
  %y = arith.constant 0 : i32
  cf.br ^bb3(%y : i32)

^bb3(%result: i32):
  return %result : i32
}

17.3 MLIR 类型系统

// 基础类型
i32                     // 32位整数
f64                     // 64位浮点
index                   // 索引类型 (平台相关)

// 张量类型
tensor<4xf32>           // 1D张量,4个f32
tensor<3x4xf32>         // 2D张量
tensor<?x?xf32>         // 动态形状张量
tensor<4x?x8xi32>       // 部分动态

// 内存引用类型
memref<10xf32>          // 1D内存,10个f32
memref<3x4xf32>         // 2D内存
memref<?xf32>           // 动态大小
memref<10xf32, affine_map<(d0) -> (d0)>, 1>  // 带布局和地址空间

// 向量类型
vector<4xf32>           // SIMD向量
vector<4x8xf32>         // 2D向量

// 函数类型
(i32, i32) -> i32       // 两参数一返回值
(i32) -> ()             // 一参数无返回值

// 索引映射类型
affine_map<(d0, d1) -> (d0, d1)>    // 仿射映射

17.4 常用 Dialect

Dialect用途抽象级别
func函数定义高层
arith算术运算中层
math数学函数中层
tensor张量操作高层
linalg线性代数高层
memref内存操作中层
scf结构化控制流中层
cf非结构化控制流低层
affine仿射循环高层
llvmLLVM 方言最低层
gpuGPU 操作高层
spirvSPIR-VGPU后端
tosaTensor 操作AI 专用

17.5 渐进式降低 (Progressive Lowering)

高级别表示:
  %c = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
                     outs(%c : tensor<4x16xf32>)
        ↓ tensor → memref
中级别表示:
  linalg.matmul ins(%a, %b : memref<4x8xf32>, memref<8x16xf32>)
                outs(%c : memref<4x16xf32>)
        ↓ linalg → scf + arith
循环表示:
  scf.for %i = ... {
    scf.for %j = ... {
      scf.for %k = ... {
        %a_val = memref.load %a[%i, %k]
        %b_val = memref.load %b[%k, %j]
        %c_val = memref.load %c[%i, %j]
        %prod = arith.mulf %a_val, %b_val
        %sum = arith.addf %c_val, %prod
        memref.store %sum, %c[%i, %j]
      }
    }
  }
        ↓ scf → cf + memref → llvm
LLVM IR 表示:
  llvm.func @matmul(...) {
    llvm.br ^bb1
    ...
  }
        ↓ LLVM 后端
机器码

17.6 Dialect 变换 Pass

17.6.1 内置 Pass

# 使用 mlir-opt 运行变换
mlir-opt input.mlir -pass-name -o output.mlir

# 常用变换
mlir-opt input.mlir \
  -linalg-bufferize \
  -tensor-bufferize \
  -func-bufferize \
  -convert-linalg-to-loops \
  -convert-scf-to-cf \
  -convert-memref-to-llvm \
  -convert-arith-to-llvm \
  -convert-func-to-llvm \
  -reconcile-unrealized-casts \
  -o output.mlir

# 一键降级到 LLVM 方言
mlir-opt input.mlir \
  -one-shot-bufferize \
  -convert-linalg-to-loops \
  -lower-affine \
  -convert-scf-to-cf \
  -finalize-memref-to-llvm \
  -convert-func-to-llvm \
  -convert-arith-to-llvm \
  -reconcile-unrealized-casts \
  -o output.mlir

17.6.2 编写自定义变换

// MyTransformPass.cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

struct MyPattern : public OpRewritePattern<arith::AddIOp> {
    using OpRewritePattern::OpRewritePattern;

    LogicalResult matchAndRewrite(arith::AddIOp op,
                                   PatternRewriter &rewriter) const override {
        // 匹配 add x, 0 → x
        auto rhs = op.getRhs().getDefiningOp<arith::ConstantOp>();
        if (rhs && rhs.getValue().cast<IntegerAttr>().getInt() == 0) {
            rewriter.replaceOp(op, op.getLhs());
            return success();
        }
        return failure();
    }
};

struct MyPass : public PassWrapper<MyPass, OperationPass<ModuleOp>> {
    MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyPass)

    StringRef getArgument() const override { return "my-pass"; }
    StringRef getDescription() const override { return "My custom pass"; }

    void runOnOperation() override {
        MLIRContext *ctx = &getContext();
        RewritePatternSet patterns(ctx);
        patterns.add<MyPattern>(ctx);
        
        if (failed(applyPatternsAndFoldGreedily(
                getOperation(), std::move(patterns)))) {
            signalPassFailure();
        }
    }
};

// 注册 Pass
void registerMyPass() {
    PassRegistration<MyPass>();
}

17.7 MLIR 工具链

工具功能
mlir-opt运行变换 Pass
mlir-translateIR 翻译(MLIR ↔ LLVM IR)
mlir-cpu-runnerCPU 上执行 MLIR
mlir-tblgenDialect/TableGen 代码生成
mlir-lsp-server语言服务器
# MLIR → LLVM IR
mlir-translate --mlir-to-llvmir input.mlir -o output.ll

# LLVM IR → MLIR
mlir-translate --import-llvm input.ll -o output.mlir

# 执行 MLIR
mlir-cpu-runner -e main -entry-point-result=i32 \
  -shared-libs=/usr/lib/libmlir_runner_utils.so input.mlir

17.8 本章小结

概念说明
Dialect操作和类型的命名空间
OperationIR 的基本单元
Region/Block嵌套结构
Progression Lowering逐步降低抽象级别
BufferizationTensor → Memref 转换

扩展阅读

  1. MLIR 官方文档 — MLIR 项目主页
  2. MLIR Tutorial — Toy 语言教程
  3. MLIR Dialects — 方言文档
  4. MLIR Language Reference — 语言参考

下一章: 第 18 章:开发环境与构建系统 — 学习 LLVM 的 Docker 开发环境和 CMake 集成。