在视觉融合与图像处理的工业应用中,”抠图”(Matting)是一个非常基础且关键的前置步骤。传统的图像处理算法(如OpenCV中的GrabCut、色度键控等)往往需要人工交互或对背景颜色有严格限制。而随着深度学习的发展,基于模型的自动抠图方案在精度和鲁棒性上都取得了质的飞跃。

本文将介绍如何在Linux环境下,使用Python结合ONNX Runtime部署 RMBG-1.4 模型,实现高效的自动抠图,并结合K-Means聚类进行色彩量化处理,为后续的工业应用(如刺绣制版、矢量化)做准备。

1. 方案选型与环境准备

1.1 核心模型

本项目选用的模型是 RMBG-1.4(Background Removal Model)。相较于庞大的PyTorch/TensorFlow环境,我们选择导出为 ONNX 格式的模型进行部署。ONNX Runtime提供了跨平台的推理加速能力,非常适合在Linux生产环境或嵌入式设备中运行。

1.2 依赖库

我们需要以下Python库:

  • onnxruntime:用于加载和推理ONNX模型。
  • opencv-python:图像读取、缩放与基本处理。
  • numpy:矩阵运算。
  • Pillow (PIL):图像格式转换与Alpha通道处理。
  • kneed:用于自动确定K-Means聚类的最佳K值(”肘部法则”)。

安装命令:

1
pip install onnxruntime opencv-python numpy pillow kneed

2. 核心代码实现

我们将核心功能封装在 RMBG2 类中,遵循 预处理 -> 推理 -> 后处理 的标准流水线。

2.1 模型加载与初始化

使用 ort.InferenceSession 加载模型,这里我们指定使用 CPU 推理(CPUExecutionProvider),如果是在Jetson等设备上,可切换为GPU相关Provider。

1
2
3
4
5
6
7
8
9
10
import onnxruntime as ort

class RMBG2():
def __init__(self, model_ptah) -> None:
self.sess_opts = ort.SessionOptions()
self.sess_opts.log_severity_level = 3
provider = ["CPUExecutionProvider"]
self.session = ort.InferenceSession(model_ptah, providers=provider, sess_options=self.sess_opts)
self.input_name = self.session.get_inputs()[0].name
self.input_shape = (1024, 1024) # 模型固定的输入尺寸

2.2 预处理 (Preprocess)

模型通常需要特定的输入尺寸(如1024x1024)和归一化方式。

  1. 尺寸调整:Resize到1024x1024。
  2. 归一化:将像素值从 [0, 255] 映射到 [0, 1],再进行标准化 (image - 0.5) / 1.0
  3. 维度变换:也就是HWC转CHW(Channel, Height, Width),并增加Batch维度。
1
2
3
4
5
6
7
8
def preprocess(self, image: np.ndarray) -> np.ndarray:
if len(image.shape) < 3:
image = np.expand_dims(image, axis=2)
image = cv2.resize(image, self.input_shape, interpolation=cv2.INTER_LINEAR)
image = image.astype(np.float32) / 255.0
image = (image - 0.5) / 1.0
image = np.transpose(image, (2, 0, 1))
return np.expand_dims(image, axis=0)

2.3 后处理 (Postprocess)

模型的输出通常是一个Mask(掩码),需要将其还原回原图尺寸,并处理成标准的Alpha通道。
线性插值还原尺寸后,我们将结果归一化到 [0, 255] 区间。

1
2
3
4
5
6
7
8
9
def postprocess(self, result: np.ndarray, original_size: tuple) -> np.ndarray:
result = cv2.resize(
np.squeeze(result),
original_size[::-1],
interpolation=cv2.INTER_LINEAR,
)
max_val, min_val = np.max(result), np.min(result)
result = (result - min_val) / (max_val - min_val)
return (result * 255).astype(np.uint8)

2.4 推理主流程

将原图输入模型,获取Mask,然后使用PIL库将Mask应用到原图的Alpha通道,实现背景透明化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def infer(self, image):
blob = self.preprocess(image)
output = self.forward(blob)
result_mask = self.postprocess(output, image.shape[:2])

# 使用PIL进行通道合成
pil_mask = Image.fromarray(result_mask)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
pil_image = pil_image.convert("RGBA")
pil_mask = pil_mask.convert("L")

output_image = Image.new("RGBA", pil_image.size, (0, 0, 0, 0))
output_image.paste(pil_image, (0, 0), pil_mask)
return output_image

3. 面向工业的后处理:色彩量化

在刺绣、印花等工业场景中,仅仅把图抠出来是不够的,我们往往还需要减少颜色数量(Color Quantization),将成千上万种颜色简化为几十种特定的线色或墨色。

3.1 自动寻找最佳K值

使用 K-Means 聚类可以提取主色,但 K 值(颜色数量)的选择是个难题。我们利用 KneeLocator 寻找“肘部点”(Knee Point),即随着K增加,聚类效果提升不再显著的那个转折点。

为了速度,我们先对前景像素进行采样(Sample Rate = 0.1),在小样本上寻找最佳K。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from kneed import KneeLocator

def fast_find_best_k(fg_pixels, k_min=4, k_max=16, sample_rate=0.1):
# 下采样以加速
if fg_pixels.shape[0] > 1000:
idx = np.random.choice(fg_pixels.shape[0], int(fg_pixels.shape[0]*sample_rate), replace=False)
sample_pixels = fg_pixels[idx]
else:
sample_pixels = fg_pixels

inertias = []
Z = sample_pixels.astype(np.float32)
# K-Means 只跑少量迭代
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)

for k in range(k_min, k_max + 1):
compactness, _, _ = cv2.kmeans(Z, k, None, criteria, 5, cv2.KMEANS_RANDOM_CENTERS)
inertias.append(compactness)

# 寻找肘部
ks = list(range(k_min, k_max + 1))
kneedle = KneeLocator(ks, inertias, S=1.0, curve="convex", direction="decreasing")
best_k = kneedle.knee if kneedle.knee else k_min
return best_k

3.2 区域连通域处理

聚类后的图像可能会有噪点,我们还可以结合连通域分析 (cv2.connectedComponents),去除面积过小的杂色块,将其合并到邻近的主色块中,使图像更加干净、适合矢量转换。

(注:这部分逻辑在代码中的 replace_by_region 函数体现)

4. 服务化部署

为了方便与其他系统集成,我们可以使用 Flask 快速搭建一个 HTTP 微服务。

1
2
3
4
5
6
7
from flask import Flask, send_file
# ... 代码省略 ...

@app.route('/matting', methods=['POST'])
def matting_api():
# 接收上传图片 -> 调用 RMBG2.infer -> 返回 PNG
pass

通过这种方式,我们可以将复杂的深度学习推理封装在独立的微服务中,前端或上位机软件只需简单的 API 调用即可获得高质量的透明背景图片。

5. 总结

本文展示的不仅仅是一个抠图 Demo,而是一个面向实际工业需求的图像处理 Pipeline:

  1. 高精度抠图:利用 ONNX Runtime 加速 RMBG 模型。
  2. 智能色彩简化:结合 K-Means 和 肘部法则 自动决定颜色数量。
  3. 工程化落地:Python + Flask + Linux 的标准组合。

这种“AI模型 + 传统CV后处理”的混合模式,往往能解决纯端到端模型难以覆盖的定制化需求。