
本文探讨了自定义对象分类器在面对非训练类别图像时,仍强制返回已知类别的问题。针对这一挑战,文章提出了一种两阶段分类策略:首先进行二元分类以判断目标对象是否存在,若存在,再进行多类别分类以识别具体类别。此方法有效解决了模型在“无匹配项”情况下误报的问题,显著提升了分类器的实用性和用户体验。
在开发基于机器学习的图像识别应用时,一个常见的问题是,当用户上传的图片不属于任何预设的训练类别时,模型仍然会强制性地从已知类别中选择一个作为结果。例如,一个水果检测应用,即使图片中没有任何水果,也可能错误地识别出一种水果。这不仅导致了不准确的输出,也严重影响了用户体验。本文将深入探讨这一问题,并提供一种行之有效的两阶段分类策略来解决它。
当前的代码实现(以及许多标准的深度学习分类模型)本质上是一个单阶段的多类别分类器。其工作原理是:给定一张图片,模型会计算该图片属于每个预设类别的概率或置信度。然后,通过选择置信度最高的类别(即argmax操作),作为最终的分类结果。
// ... (代码省略,表示图像预处理和模型推理) ...
float[] confidences = outputFeature0.getFloatArray();
int maxPos = 0;
float maxConfidence = 0;
for (int i = 0; i < confidences.length; i++) {
if (confidences[i] > maxConfidence) {
maxConfidence = confidences[i];
maxPos = i;
}
}
String[] classes = { /* 所有训练类别 */ };
result.setText(classes[maxPos]); // 总是返回一个已知类别这种方法的问题在于,模型被训练来区分 已知类别之间 的差异,而不是区分 已知类别与未知类别。当输入图像与所有训练类别都相去甚远时,模型仍然会计算出一个“最高”的置信度,即使这个置信度本身很低,也会被选为最终结果。模型内部并没有一个“都不是”的输出选项。
为了解决上述问题,我们可以采用一种两阶段的分类策略。这种方法将识别过程分解为两个独立的逻辑步骤:
这一阶段的目标是构建一个独立的分类器,其任务是判断输入图像中是否包含任何我们感兴趣的对象。对于水果检测应用,这意味着训练一个模型来区分“包含水果的图像”和“不包含水果的图像”。
如果第一阶段的二元分类器判断图像中存在目标对象,那么我们才将图像输入到原有的多类别分类器中,以识别其具体的类别。这一阶段使用的就是用户现有代码中的多类别分类逻辑。
以下是结合两阶段策略的 classifyImage 方法的修改示例。请注意,isFruitPresent 方法是一个概念性的函数,代表了运行二元分类模型并获取结果的过程。
import android.graphics.Bitmap;
import android.view.View;
import android.widget.TextView;
import org.tensorflow.lite.DataType;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
// 假设 FruitDisease 和 BinaryFruitDetectorModel 是通过 TFLite Model Maker 或其他方式生成的模型接口
// import com.example.your_app.ml.FruitDisease;
// import com.example.your_app.ml.BinaryFruitDetectorModel; // 假设这是你的二元分类模型
public class ImageClassifier {
private TextView result; // 假设这是显示最终分类结果的TextView
private TextView confidence; // 假设这是显示置信度列表的TextView
private int imageSize = 224; // 模型输入图片尺寸
// 构造函数或初始化方法,用于传入TextView实例
public ImageClassifier(TextView resultTextView, TextView confidenceTextView) {
this.result = resultTextView;
this.confidence = confidenceTextView;
}
private void classifyImage(Bitmap image) {
// 确保图片尺寸符合模型输入要求
Bitmap scaledImage = Bitmap.createScaledBitmap(image, imageSize, imageSize, false);
try {
// 图像预处理:将Bitmap转换为ByteBuffer,适用于TFLite模型输入
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * imageSize * imageSize * 3);
byteBuffer.order(ByteOrder.nativeOrder());
int[] intValue = new int[imageSize * imageSize];
scaledImage.getPixels(intValue, 0, scaledImage.getWidth(), 0, 0, scaledImage.getWidth(), scaledImage.getHeight());
int pixel = 0;
for (int i = 0; i < imageSize; i++) {
for (int j = 0; j < imageSize; j++) {
int val = intValue[pixel++];
// 归一化到 [0, 1] 范围
byteBuffer.putFloat(((val >> 16) & 0xFF) * (1.f / 255.f)); // R
byteBuffer.putFloat(((val >> 8) & 0xFF) * (1.f / 255.f)); // G
byteBuffer.putFloat((val & 0xFF) * (1.f / 255.f)); // B
}
}
// 创建TensorBuffer作为模型输入
TensorBuffer inputFeature = TensorBuffer.createFixedSize(new int[]{1, imageSize, imageSize, 3}, DataType.FLOAT32);
inputFeature.loadBuffer(byteBuffer);
// =================================================================
// 步骤一:二元分类 - 判断是否存在水果
// 假设我们有一个名为 BinaryFruitDetectorModel 的TFLite模型用于二元分类
boolean fruitDetected = false;
BinaryFruitDetectorModel binaryModel = null; // 声明在try块外部,以便finally关闭
try {
binaryModel = BinaryFruitDetectorModel.newInstance(getApplicationContext()); // 替换为实际的context获取方式
BinaryFruitDetectorModel.Outputs binaryOutputs = binaryModel.process(inputFeature);
TensorBuffer binaryOutputBuffer = binaryOutputs.getOutputFeature0AsTensorBuffer();
float[] binaryConfidences = binaryOutputBuffer.getFloatArray();
// 假设 binaryConfidences[0] 是“无水果”的置信度,binaryConfidences[1] 是“有水果”的置信度
// 或者,如果模型输出是单个值,例如 > 0.5 表示有水果
float fruitPresenceConfidence = binaryConfidences[1]; // 或根据你的模型输出调整
float DETECTION_THRESHOLD = 0.7f; // 设置一个检测阈值
if (fruitPresenceConfidence > DETECTION_THRESHOLD) {
fruitDetected = true;
}
} catch (IOException e) {
e.printStackTrace();
result.setText("二元分类模型加载失败");
confidence.setVisibility(View.GONE);
return;
} finally {
if (binaryModel != null) {
binaryModel.close(); // 关闭二元分类模型
}
}
if (!fruitDetected) {
result.setText("未检测到水果");
confidence.setVisibility(View.GONE);
return; // 如果没有检测到水果,则直接返回
}
// =================================================================
// 步骤二:多类别分类 - 识别具体水果类型(只有在检测到水果后才执行)
FruitDisease multiClassModel = null; // 声明在try块外部,以便finally关闭
try {
multiClassModel = FruitDisease.newInstance(getApplicationContext()); // 替换为实际的context获取方式
FruitDisease.Outputs outputs = multiClassModel.process(inputFeature);
TensorBuffer outputFeature0 = outputs.getOutputFeature0AsTensorBuffer();
float[] confidences = outputFeature0.getFloatArray();
int maxPos = 0;
float maxConfidence = 0;
for (int i = 0; i < confidences.length; i++) {
if (confidences[i] > maxConfidence) {
maxConfidence = confidences[i];
maxPos = i;
}
}
String[] classes = {"Watermelon Healthy", "Watermelon Blossom End Rot", "Watermelon Anthracnose",
"Mango Healthy", "Mango Bacterial Canker", "Mango Anthracnose",
"Orange Scab", "Orange Healthy",
"Orange Bacterial Citrus Canker", "Banana Healthy", "Banana Crown Rot",
"Banana Anthracnose", "Apple Scab", "Apple Healthy", "Apple Black Rot Canker"};
result.setText(classes[maxPos]);
StringBuilder s = new StringBuilder();
for (int i = 0; i < classes.length; i++) {
s.append(String.format("%s: %.1f%%\n", classes[i], confidences[i] * 100));
}
confidence.setText(s.toString());
confidence.setVisibility(View.VISIBLE);
} catch (IOException e) {
e.printStackTrace();
result.setText("多类别分类模型加载失败");
confidence.setVisibility(View.GONE);
} finally {
if (multiClassModel != null) {
multiClassModel.close(); // 关闭多类别分类模型
}
}
} catch (Exception e) {
e.printStackTrace();
result.setText("图像分类过程中发生错误");
confidence.setVisibility(View.GONE);
} finally {
// 确保释放Bitmap资源,如果不再需要
if (scaledImage != null && !scaledImage.isRecycled()) {
scaledImage.recycle();
}
}
}
// 这是一个占位符方法,需要根据你的实际应用上下文获取
// 通常在Activity或Fragment中调用,可以传入getApplicationContext()
private android.content.Context getApplicationContext() {
// 实际应用中,你需要从调用这个classifyImage方法的Activity/Fragment中获取Context
// 例如:return myActivity.getApplicationContext();
throw new UnsupportedOperationException("getApplicationContext() method needs to be implemented by the caller.");
}
}注意事项:
另一种思路是在原始的多类别分类模型中增加一个额外的类别,例如“无水果”或“背景”。这样,模型就有了 N 个水果类别和 1 个“无水果”类别,总共 N+1 个类别。
// 修改后的类别列表,包含一个“无水果”类别
String[] classes = {"Watermelon Healthy", ..., "Apple Black Rot Canker", "No Fruit"};
// 模型训练时也需要包含“No Fruit”类别的样本此方法的缺点:
相比之下,两阶段分类策略在处理“无匹配项”场景时通常更具鲁棒性和可维护性。
当自定义对象分类器需要处理“无匹配项”的输入时,简单地依赖多类别分类器的最高置信度输出是不够的。采用两阶段分类策略,即先通过二元分类判断目标对象是否存在,再进行多类别分类识别具体类型,能够显著提升模型的准确性和应用的健壮性。这种方法不仅改善了用户体验,也为构建更智能、更可靠的机器学习应用提供了有效途径。
以上就是提升自定义对象分类器鲁棒性:处理“无匹配项”场景的策略的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号