图像分类与绘图识别应用开发指南
立即解锁
发布时间: 2025-08-30 00:33:49 阅读量: 12 订阅数: 17 AIGC 


Swift实战AI开发指南
# 图像分类与绘图识别应用开发指南
## 1. 图像分类模型集成到应用中
### 1.1 前期准备
若要将训练好的水果识别模型集成到应用中,你需要先完成起始应用的搭建,可按照相关说明自行构建,也可从指定网站下载名为 `ICDemo - Starter` 的代码和项目。若不想手动添加 AI 功能,还能下载 `ICDemo - Complete` 项目。
### 1.2 代码修改步骤
为使应用能与模型协同工作,需进行以下修改:
1. 在 `inputImage` 和 `classification` 旁添加新变量 `classifier`:
```swift
private let classifier = VisionClassifier(mlmodel: BananaOrApple().model)
```
2. 在 `viewDidLoad()` 末尾将新变量的委托设为 `self`,然后调用 `refresh()`:
```swift
classifier?.delegate = self
refresh()
```
3. 在 `refresh()` 函数的第一个 `if` 语句末尾添加调用,禁用 `classifyImageButton`:
```swift
classifyImageButton.disable()
```
4. 替换 `classifyImage()` 的定义:
```swift
private func classifyImage() {
if let classifier = self.classifier, let image = inputImage {
classifier.classify(image)
classifyImageButton.disable()
}
}
```
### 1.3 添加新文件及代码
接下来,在项目中添加名为 `Vision.swift` 的新 Swift 文件,并添加以下代码:
```swift
import UIKit
import CoreML
import Vision
extension VNImageRequestHandler {
convenience init?(uiImage: UIImage) {
guard let ciImage = CIImage(image: uiImage) else { return nil }
let orientation = uiImage.cgImageOrientation
self.init(ciImage: ciImage, orientation: orientation)
}
}
class VisionClassifier {
private let model: VNCoreMLModel
private lazy var requests: [VNCoreMLRequest] = {
let request = VNCoreMLRequest(
model: model,
completionHandler: {
[weak self] request, error in
self?.handleResults(for: request, error: error)
})
request.imageCropAndScaleOption = .centerCrop
return [request]
}()
var delegate: ViewController?
init?(mlmodel: MLModel) {
if let model = try? VNCoreMLModel(for: mlmodel) {
self.model = model
} else {
return nil
}
}
func classify(_ image: UIImage) {
DispatchQueue.global(qos: .userInitiated).async {
guard let handler =
VNImageRequestHandler(uiImage: image) else {
return
}
do {
try handler.perform(self.requests)
} catch {
self.delegate?.summonAlertView(
message: error.localizedDescription
)
}
}
}
func handleResults(for request: VNRequest, error: Error?) {
DispatchQueue.main.async {
guard let results =
request.results as? [VNClassificationObservation] else {
self.delegate?.summonAlertView(
message: error?.localizedDescription
)
return
}
if results.isEmpty {
self.delegate?.classification = "Don't see a thing!"
} else {
let result = results[0]
if result.confidence < 0.6 {
self.delegate?.classification = "Not quite sure..."
} else {
self.delegate?.classification =
"\(result.identifier) " +
"(\(Int(result.confidence * 100))%)"
}
}
self.delegate?.refresh()
}
}
}
```
在 `Vision.swift` 文件末尾添加以下扩展:
```swift
extension UIImage {
var cgImageOrientation: CGImagePropertyOrientation {
switch self.imageOrientation {
case .up: return .up
case .down: return .down
case .left: return .left
case .right: return .right
case .upMirrored: return .upMirrored
case .downMirrored: return .downMirrored
case .leftMirrored: return .leftMirrored
case .rightMirrored: return .rightMirrored
}
}
}
```
### 1.4 模型导入与应用测试
将 `WhatsMyFruit.mlmodel` 文件拖入项目根目录,允许 Xcode 复制该文件。此时可在模拟器中启动应用,选择图像(若在真机上运行可拍照),点击 “Classify Image” 按钮,模型将进行分类,标签会更新显示分类结果。
### 1.5 应用改进
若想让应用能对更多水果进行分类,可使用 Apple 的 CreateML 应用重新训练模型。选择 `Fruit - 360` 数据集中的完整 `Training` 文件夹(包含 103 种不同水果类别),训练新的图像分类模型。将新模型拖入 Xcode 项目,更新 `ViewController.swift` 中的以下代码:
```swift
private let classifier = VisionClassifier(mlmodel: BananaOrApple().model)
```
例如,若新模型名为 `Fruits360.mlmodel`,则更新为:
```swift
private let classifier = VisionClassifier(mlmodel: Fruits360().model)
```
重新启动应用,即可检测 103 种不同的水果。
## 2. 绘图识别应用开发
### 2.1 问题与解决思路
随着 iPad Pro 和 Apple Pencil 的出现,在苹果移动设备上绘图变得愈发流行。开发绘图识别应用有诸多用途,如开发绘图游戏、将绘制内容转换为表情符号等。我们将通过以下步骤探索绘图检测的实践方法:
- 构建一个允许用户拍摄绘图照片并进行分类的应用。
- 寻找或组装数据,训练能对绘图进行分类的模型。
- 探索提升绘图分类效果的后续步骤。
### 2.2 AI 工具包与数据集
此任务主要使用 Turi Create、CoreML 和 Vision 工具。我们将使用 Turi Create 训练绘图分类模型,再用 CoreML 和 Vision 与模型协作,对用户拍摄的绘图照片进行分类。
为训练模型,需要绘图数据集。Google 的 Quick Draw 数据集包含超过 5000 万张草图,分为 345 个类别。由于类别过多,训练时间较长,我们选择以下 23 个类别:苹果、香蕉、面包、西兰花、蛋糕、胡萝卜、咖啡杯、饼干、甜甜圈、葡萄、热狗、冰淇淋、棒棒糖、蘑菇、花生、梨、菠萝、披萨、土豆、三明治、牛排、草莓和西瓜。
### 2.3 模型创建
#### 2.3.1 Python 环境搭建
按照相关流程搭建 Python 环境,激活环境并使用 `pip` 安装 Turi Create:
```bash
conda create -n TuriCreateDrawingClassifierEnvironment python=3.6
conda activate TuriCreateDrawingClassifierEnvironment
pip install turicreate
```
#### 2.3.2 创建 Python 脚本
创建名为 `train_drawing_classifier.py` 的新 Python 脚本,并添加以下代码:
```python
#!/usr/bin/env python
import os
import json
import requests
import numpy as np
import turicreate as tc
```
#### 2.3.3 配置变量
添加配置变量,包括要训练的类别列表:
```python
# THE CATEGORIES WE WANT TO BE ABLE TO DISTINGUISH
categories = [
'apple', 'banana', 'bread', 'broccoli', 'cake', 'carrot', 'coffee cup',
'cookie', 'donut', 'grapes', 'hot dog', 'ice cream', 'lollipop',
'mushroom', 'peanut', 'pear', 'pineapple', 'pizza', 'potato',
'sandwich', 'steak', 'strawberry', 'watermelon'
]
# CONFIGURE AS REQUIRED
this_directory = os.path.dirname(os.path.realpath(__file__))
quickdraw_directory = this_directory + '/quickdraw'
bitmap_directory = quickdraw_directory + '/bitmap'
bitmap_sframe_path = quickdraw_directory + '/bitmaps.sframe'
output_model_filename = this_directory + '/DrawingClassifierModel'
training_samples = 10000
```
#### 2.3.4 创建目录
添加函数创建用于存放训练数据的目录:
```python
# MAKE SOME FOLDERS TO PUT TRAINING DATA IN
def make_directory(path):
try:
os.makedirs(path)
except OSError:
if not os.path.isdir(path):
raise
make_directory(quickdraw_directory)
make_directory(bitmap_directory)
```
#### 2.3.5 下载训练数据
下载用于训练的位图数据:
```python
# FETCH SOME DATA
bitmap_url = (
'https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap'
)
total_categories = len(categories)
for index, category in enumerate(categories):
bitmap_filename = '/' + category + '.npy'
with open(bitmap_directory + bitmap_filename, 'w+') as bitmap_file:
bitmap_response = requests.get(bitmap_url + bitmap_filename)
bitmap_file.write(bitmap_response.content)
print('Downloaded %s drawings (category %d/%d)' %
(category, index + 1, total_categories))
random_state = np.random.RandomState(100)
```
#### 2.3.6 创建 SFrames
添加函数从图像创建 SFrames:
```python
def get_bitmap_sframe():
labels, drawings = [
```
0
0
复制全文
相关推荐










