Using CreateML and Applying It in iOS


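The aPaaS Growth team focuses on the user-visible, end-to-end process of building aPaaS applications, as well as product paths such as tenant and application governance. Its goals are to make "application delivery" on the aPaaS platform smooth, round out the ecosystem around application building, make building applications easier and more reliable, and improve overall application performance, thereby driving user growth for aPaaS and, together with the platform team, advancing aPaaS adoption and efficiency inside and outside the company.

In the low-code/no-code space, MS Power Platform and AWS Amplify both offer products similar to AI Builder, which let users train their own deep learning models with a low barrier to entry. CreateML is the equivalent in Apple's ecosystem. It ships with Xcode, so anyone with Xcode installed can open it and give it a try (you need to bring your own dataset).

What is CreateML

Create ML is a tool for generating machine learning models that Apple introduced at WWDC 2018. It takes data supplied by the user and produces the machine learning models (Core ML models) used in iOS development.

In iOS development, there are a few main ways to obtain a machine learning model.

This year's Create ML goes a step further in ease of use: it can be used without writing any code, it is now a standalone macOS app, and it supports more data types and use cases.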

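CreateML model types

  1. Image Classification
  2. Object Detection
  3. Style Transfer
  4. Hand Pose & Hand Action
  5. Action Classification
  6. Activity Classification
  7. Sound Classification: imagine implementing something like "Hey Siri".
  8. Text Classification
  9. Word Tagging
  10. Tabular Classification & Regression: predict one dimension from several others, e.g. infer someone's income level from their gender, age, city, and so on.
  11. Recommendation: e.g. if you bought beer, you get recommended peanuts. Historically there have also been algorithms for this that are not based on deep learning, such as Apriori.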
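The beer-and-peanuts example is the textbook setting for association rules. As an illustration only (this is not how Create ML's Recommendation template works), here is a minimal TypeScript sketch of the counting Apriori does in its first two passes: keep items whose support clears a threshold, count co-occurring pairs of those items, and report rules whose confidence is high enough. `minePairRules` and the sample transactions are hypothetical; full Apriori keeps extending itemsets beyond pairs.

```typescript
// Simplified Apriori-style association-rule mining, limited to item pairs.
interface PairRule {
  antecedent: string;
  consequent: string;
  support: number;    // fraction of transactions containing both items
  confidence: number; // support(both) / support(antecedent)
}

function minePairRules(transactions: string[][], minSupport: number, minConfidence: number): PairRule[] {
  const n = transactions.length;

  // Pass 1: count how many transactions contain each item
  const itemCounts = new Map<string, number>();
  for (const t of transactions) {
    for (const item of Array.from(new Set(t))) {
      itemCounts.set(item, (itemCounts.get(item) ?? 0) + 1);
    }
  }
  // Apriori pruning: a pair can only be frequent if both of its items are
  const frequent = new Set(Array.from(itemCounts.keys()).filter(i => itemCounts.get(i)! / n >= minSupport));

  // Pass 2: count pairs of frequent items that appear in the same transaction
  const pairCounts = new Map<string, [string, string, number]>();
  for (const t of transactions) {
    const items = Array.from(new Set(t)).filter(i => frequent.has(i)).sort();
    for (let a = 0; a < items.length; a++) {
      for (let b = a + 1; b < items.length; b++) {
        const key = JSON.stringify([items[a], items[b]]);
        const prev = pairCounts.get(key);
        pairCounts.set(key, [items[a], items[b], (prev ? prev[2] : 0) + 1]);
      }
    }
  }

  // Turn frequent pairs into rules in both directions, filtered by confidence
  const rules: PairRule[] = [];
  pairCounts.forEach(([x, y, count]) => {
    const support = count / n;
    if (support < minSupport) return;
    for (const [from, to] of [[x, y], [y, x]]) {
      const confidence = count / itemCounts.get(from)!;
      if (confidence >= minConfidence) {
        rules.push({ antecedent: from, consequent: to, support, confidence });
      }
    }
  });
  return rules;
}

const transactions = [
  ["beer", "peanuts", "chips"],
  ["beer", "peanuts"],
  ["beer", "diapers"],
  ["milk", "bread"],
];
console.log(minePairRules(transactions, 0.5, 0.6));
```

With these four transactions only beer (support 0.75) and peanuts (support 0.5) pass the 0.5 support threshold, so the result is beer → peanuts with confidence 2/3 and peanuts → beer with confidence 1.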

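Trying out a CreateML model

Training an object detection model with CreateML

Data preparation

Some people assume the hard part of training a deep learning model is finding the right algorithm or architecture and running enough training iterations on a powerful enough machine. In practice, the single most important factor for a deep learning model is having enough accurate data; that is also the main reason the AI industry tends to be dominated by a few leading players. If you are building an AI-related application, your main concern should be how to get enough accurate data.

Below, I'll use the model from the "Trying out" section above as an example to explain how to train a similar model.

Data format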

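The data format for CreateML object detection is shown in the figure below:

First, there is a file named annotions.json that records, for each image, how many objects it contains and the coordinates of each object's bounding box.

For example, the bounding boxes for the image above are as follows: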
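The annotation structure described above can be sketched as TypeScript types. This is a sketch based on Apple's documented layout for object-detection annotations: a JSON array where each entry names an image and lists labeled boxes whose `x`/`y` are the box center in pixels, not the top-left corner. Apple's documentation spells the array key `annotations`, whereas the script's `Annotion` type uses `annotions`; that spelling is worth checking if Create ML does not pick up your labels. `boxFromTopLeft` and the file and label names are hypothetical.

```typescript
// Shape of a Create ML object-detection annotation file (a JSON array of these entries).
// Coordinates are in pixels, and (x, y) is the CENTER of the box, not its top-left corner.
interface CreateMLBox {
  label: string;
  coordinates: { x: number; y: number; width: number; height: number };
}

interface CreateMLEntry {
  image: string; // image file name
  annotations: CreateMLBox[];
}

// Hypothetical helper: convert a box given by its top-left corner to Create ML's center-based form
function boxFromTopLeft(label: string, left: number, top: number, width: number, height: number): CreateMLBox {
  return { label, coordinates: { x: left + width / 2, y: top + height / 2, width, height } };
}

// Example entry with made-up file and label names
const entry: CreateMLEntry = {
  image: "character_a-0.jpg",
  annotations: [boxFromTopLeft("character_a", 100, 50, 200, 300)],
};

console.log(JSON.stringify([entry], null, 4));
```

A 200×300 box whose top-left corner is at (100, 50) is stored with its center, (200, 200).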

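Preparing enough data

The first question is what counts as "enough" data. A few public datasets give a sense of scale:

  1. Stanford Cars Dataset: 934MB. The Cars dataset contains 16,185 images of 196 classes of cars. The data is split into 8,144 training images and 8,041 testing images.
  2. https://www.kaggle.com/datasets/kmader/food41: labeled food images in 101 categories, from apple pies to waffles, 6GB.

In the example above there are roughly 40 Genshin Impact characters, so we need on the order of a hundred MB of data to get training started, then add more samples if accuracy is not high enough. The problem is where to find that much data. The approach I came up with was to synthesize it with a script. Since our task is only to locate and extract the characters' "ID photos" (portrait cards) in an image, I took the ID photos of about 40 characters and wrote the following script (Copilot helped a lot...) to generate a training/test set of about 500MB: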


import { createCanvas, Image } from "@napi-rs/canvas";
import { promises } from "fs";
import fs from "fs";
import path from "path";
import Sharp from "sharp";

const IMAGE_GENERATED_COUNT_PER_CLASS = 5;
const MAX_NUMBER_OF_CLASSES_IN_SINGLE_IMAGE = 10;
const CANVAS_WIDTH = 1024;
const CANVAS_HEIGHT = 800;
const CONCURRENT_PROMISE_SIZE = 50;

const CanvasSize = [CANVAS_WIDTH, CANVAS_HEIGHT];

function isNotOverlap(x1: number, y1: number, width1: number, height1: number, x2: number, y2: number, width2: number, height2: number) {
    return x1 >= x2 + width2 || x1 + width1 <= x2 || y1 >= y2 + height2 || y1 + height1 <= y2;
}

const randomColorList: Record<string, string> = {
    "white": "rgb(255, 255, 255)",
    "black": "rgb(0, 0, 0)",
    "red": "rgb(255, 0, 0)",
    "green": "rgb(0, 255, 0)",
    "blue": "rgb(0, 0, 255)",
    "yellow": "rgb(255, 255, 0)",
    "cyan": "rgb(0, 255, 255)",
    "magenta": "rgb(255, 0, 255)",
    "gray": "rgb(128, 128, 128)",
    "grey": "rgb(128, 128, 128)",
    "maroon": "rgb(128, 0, 0)",
    "olive": "rgb(128, 128, 0)",
    "purple": "rgb(128, 0, 128)",
    "teal": "rgb(0, 128, 128)",
    "navy": "rgb(0, 0, 128)",
    "orange": "rgb(255, 165, 0)",
    "aliceblue": "rgb(240, 248, 255)",
    "antiquewhite": "rgb(250, 235, 215)",
    "aquamarine": "rgb(127, 255, 212)",
    "azure": "rgb(240, 255, 255)",
    "beige": "rgb(245, 245, 220)",
    "bisque": "rgb(255, 228, 196)",
    "blanchedalmond": "rgb(255, 235, 205)",
    "blueviolet": "rgb(138, 43, 226)",
    "brown": "rgb(165, 42, 42)",
    "burlywood": "rgb(222, 184, 135)",
    "cadetblue": "rgb(95, 158, 160)",
    "chartreuse": "rgb(127, 255, 0)",
    "chocolate": "rgb(210, 105, 30)",
    "coral": "rgb(255, 127, 80)",
    "cornflowerblue": "rgb(100, 149, 237)",
    "cornsilk": "rgb(255, 248, 220)",
    "crimson": "rgb(220, 20, 60)",
    "darkblue": "rgb(0, 0, 139)",
    "darkcyan": "rgb(0, 139, 139)",
    "darkgoldenrod": "rgb(184, 134, 11)",
    "darkgray": "rgb(169, 169, 169)",
    "darkgreen": "rgb(0, 100, 0)",
    "darkgrey": "rgb(169, 169, 169)",
    "darkkhaki": "rgb(189, 183, 107)",
    "darkmagenta": "rgb(139, 0, 139)",
    "darkolivegreen": "rgb(85, 107, 47)",
    "darkorange": "rgb(255, 140, 0)",
    "darkorchid": "rgb(153, 50, 204)",
    "darkred": "rgb(139, 0, 0)"
}

function generateColor(index: number = -1) {
    if (index < 0 || index >= Object.keys(randomColorList).length) {
        // return random color from list
        let keys = Object.keys(randomColorList);
        let randomKey = keys[Math.floor(Math.random() * keys.length)];
        return randomColorList[randomKey];
    } else {
        // return color by index
        let keys = Object.keys(randomColorList);
        return randomColorList[keys[index]];
    }
}

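// Pick a random position for each image. Returns [x, y, placedFlag] per image, where
// placedFlag is 1 if the position was accepted. When `overlapping` is false, a position is
// rejected if it intersects any image accepted so far (one attempt per image, no retry).
// Note: the overlap test uses the pre-stretch size and treats (x, y) as the top-left corner,
// while the drawing code later treats it as the center, so non-overlap is only approximate.
function randomPlaceImagesInCanvas(canvasWidth: number, canvasHeight: number, images: number[][], overlapping: boolean = true) {
    let acceptedRects: number[][] = []; // [x, y, width, height] of accepted images
    let placedImages: number[][] = [];
    for (let image of images) {
        let [width, height] = image;
        let [x, y] = [Math.floor(Math.random() * (canvasWidth - width)), Math.floor(Math.random() * (canvasHeight - height))];
        // The first image is always accepted; later ones must not collide with accepted ones
        let placed = overlapping || acceptedRects.every(([ax, ay, aw, ah]) => isNotOverlap(x, y, width, height, ax, ay, aw, ah));
        if (placed) {
            acceptedRects.push([x, y, width, height]);
        }
        placedImages.push([x, y, placed ? 1 : 0]);
    }
    return placedImages;
}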

// Stretch only the width by `ratio` (height unchanged), distorting the aspect ratio
function getSizeBasedOnRatio(width: number, height: number, ratio: number) {
    return [width * ratio, height];
}

function cartesianProductOfArray(...arrays: any[][]) {
    return arrays.reduce((a, b) => a.flatMap((d: any) => b.map((e: any) => [d, e].flat())));
}

function rotateRectangleAndGetSize(width: number, height: number, angle: number) {
    let radians = angle * Math.PI / 180;
    let cos = Math.abs(Math.cos(radians));
    let sin = Math.abs(Math.sin(radians));
    let newWidth = Math.ceil(width * cos + height * sin);
    let newHeight = Math.ceil(height * cos + width * sin);
    return [newWidth, newHeight];
}

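// Waits for the given promises in batches of `size`.
// Note: a Promise starts running as soon as it is created, so when callers pass
// already-created promises (as this script does), this only batches the waiting;
// it does not limit how many tasks run at the same time.
function concurrentlyExecutePromisesWithSize(promises: Promise<any>[], size: number): Promise<void> {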
    let promisesToExecute = promises.slice(0, size);
    let promisesToWait = promises.slice(size);
    return Promise.all(promisesToExecute).then(() => {
        if (promisesToWait.length > 0) {
            return concurrentlyExecutePromisesWithSize(promisesToWait, size);
        }
    });
}

function generateRandomRgbColor() {
    return [Math.floor(Math.random() * 256), Math.floor(Math.random() * 256), Math.floor(Math.random() * 256)];
}

function getSizeOfImage(image: Image) {
    return [image.width, image.height];
}

async function makeSureFolderExists(path: string) {
    if (!fs.existsSync(path)) {
        await promises.mkdir(path, { recursive: true });
    }
}

// Randomly select `count` distinct elements from the array (sampling without replacement)
async function randomSelectFromArray<T>(array: T[], count: number) {
    let copied = array.slice();
    let selected: T[] = [];
    for (let i = 0; i < count; i++) {
        let index = Math.floor(Math.random() * copied.length);
        selected.push(copied[index]);
        copied.splice(index, 1);
    }
    return selected;
}

function getFileNameFromPathWithoutPrefix(path: string) {
    return path.split("/").pop()!.split(".")[0];
}

type Annotion = {
    "image": string,
    "annotions": {
        "label": string,
        "coordinates": {
            "x": number,
            "y": number,
            "width": number,
            "height": number
        }
    }[]
}

async function generateCreateMLFormatOutput(folderPath: string, outputDir: string, imageCountPerFile: number = IMAGE_GENERATED_COUNT_PER_CLASS) {

    if (!fs.existsSync(path.join(folderPath, "real"))) {
        throw new Error("real folder does not exist");
    }

    let realFiles = fs.readdirSync(path.join(folderPath, "real")).map((file) => path.join(folderPath, "real", file));
    let confusionFiles: string[] = [];

    if (fs.existsSync(path.join(folderPath, "confusion"))) {
        confusionFiles = fs.readdirSync(path.join(folderPath, "confusion")).map((file) => path.join(folderPath, "confusion", file));
    }

    // getting files in folder
    let tasks: Promise<void>[] = [];
    let annotions: Annotion[] = [];

    for (let filePath of realFiles) {

        let className = getFileNameFromPathWithoutPrefix(filePath);

        for (let i = 0; i < imageCountPerFile; i++) {

            let annotion: Annotion = {
                "image": `${className}-${i}.jpg`,
                "annotions": []
            };

            async function __task(i: number) {

                // Number of extra images to scatter on this canvas: an integer, capped by how many files exist
                let randomCount = Math.floor(Math.random() * MAX_NUMBER_OF_CLASSES_IN_SINGLE_IMAGE);
                randomCount = Math.min(randomCount, realFiles.length + confusionFiles.length);
                let selectedFiles = await randomSelectFromArray(realFiles.concat(confusionFiles), randomCount);
                if (selectedFiles.includes(filePath)) {
                    // move filePath to the first
                    selectedFiles.splice(selectedFiles.indexOf(filePath), 1);
                    selectedFiles.unshift(filePath);
                } else {
                    selectedFiles.unshift(filePath);
                }

                console.log(`processing ${filePath} ${i}, selected ${selectedFiles.length} files`);

                let images = await Promise.all(selectedFiles.map(async (filePath) => {
                    let file = await promises.readFile(filePath);
                    let image = new Image();
                    image.src = file;
                    return image;
                }));

                console.log(`processing: ${filePath}, loaded images, start to place images in canvas`);

                let imageSizes = images.map(getSizeOfImage).map( x => {
                    let averageX = CanvasSize[0] / (images.length + 1);
                    let averageY = CanvasSize[1] / (images.length + 1);
                    return [x[0] > averageX ? averageX : x[0], x[1] > averageY ? averageY : x[1]];
                });

                let placedPoints = randomPlaceImagesInCanvas(CANVAS_WIDTH, CANVAS_HEIGHT, imageSizes, false);

                console.log(`processing: ${filePath}, placed images in canvas, start to draw images`);

                let angle = 0;
                let color = generateColor(i);

                let [canvasWidth, canvasHeight] = CanvasSize;
                const canvas = createCanvas(canvasWidth, canvasHeight);
                const ctx = canvas.getContext("2d");

                ctx.fillStyle = color;
                ctx.fillRect(0, 0, canvasWidth, canvasHeight);

                for (let j = 0; j < images.length; j++) {
                    const ctx = canvas.getContext("2d");

                    let ratio = Math.random() * 1.5 + 0.5;

                    let image = images[j];

                    let [_imageWidth, _imageHeight] = imageSizes[j];
                    let [imageWidth, imageHeight] = getSizeBasedOnRatio(_imageWidth, _imageHeight, ratio);

                    let placed = placedPoints[j][2] === 1 ? true : false;
                    if (!placed) {
                        continue;
                    }

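                    // Clamp the center so the whole stretched image (and, with angle = 0, its bounding box) stays inside the canvas
                    let targetX = Math.min(Math.max(placedPoints[j][0], imageWidth / 2), canvasWidth - imageWidth / 2);
                    let targetY = Math.min(Math.max(placedPoints[j][1], imageHeight / 2), canvasHeight - imageHeight / 2);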

                    let sizeAfterRotatation = rotateRectangleAndGetSize(imageWidth, imageHeight, angle);

                    console.log("final: ", [canvasWidth, canvasHeight], [imageWidth, imageHeight], [targetX, targetY], angle, ratio, color);

                    ctx.translate(targetX, targetY);
                    ctx.rotate(angle * Math.PI / 180);

                    ctx.drawImage(image, -imageWidth / 2, -imageHeight / 2, imageWidth, imageHeight);

                    ctx.rotate(-angle * Math.PI / 180);
                    ctx.translate(-targetX, -targetY);

                    // ctx.fillStyle = "green";
                    // ctx.strokeRect(targetX - sizeAfterRotatation[0] / 2, targetY - sizeAfterRotatation[1] / 2, sizeAfterRotatation[0], sizeAfterRotatation[1]);

                    annotion.annotions.push({
                        "label": getFileNameFromPathWithoutPrefix(selectedFiles[j]),
                        "coordinates": {
                            "x": targetX,
                            "y": targetY,
                            "width": sizeAfterRotatation[0],
                            "height": sizeAfterRotatation[1]
                        }
                    });
                }

                if (!annotion.annotions.length) {
                    return;
                }

                let fileName = path.join(outputDir, `${className}-${i}.jpg`);
                let jpegData = await canvas.encode("jpeg");
                await promises.writeFile(fileName, jpegData);

                annotions.push(annotion);
            }

            tasks.push(__task(i));

        }

    }

    await concurrentlyExecutePromisesWithSize(tasks, CONCURRENT_PROMISE_SIZE);

    await promises.writeFile(path.join(outputDir, "annotions.json"), JSON.stringify(annotions, null, 4));

}

async function generateYoloFormatOutput(folderPath: string) {
    const annotions = JSON.parse((await promises.readFile(path.join(folderPath, "annotions.json"))).toString("utf-8")) as Annotion[];

    // generate data.yml
    let classes: string[] = [];
    for (let annotion of annotions) {
        for (let label of annotion.annotions.map(a => a.label)) {
            if (!classes.includes(label)) {
                classes.push(label);
            }
        }
    }

    let dataYml = `
train: ./train/images
val: ./valid/images
test: ./test/images

nc: ${classes.length}
names: ${JSON.stringify(classes)}
`
    await promises.writeFile(path.join(folderPath, "data.yml"), dataYml);

    const weights = [0.85, 0.90, 0.95];
    const split = ["train", "valid", "test"];

    let tasks: Promise<void>[] = [];

    async function __task(annotion: Annotion) {
        const randomSeed = Math.random();
        let index = 0;
        for (let i = 0; i < weights.length; i++) {
            if (randomSeed < weights[i]) {
                index = i;
                break;
            }
        }
        let splitFolderName = split[index];
        await makeSureFolderExists(path.join(folderPath, splitFolderName));
        await makeSureFolderExists(path.join(folderPath, splitFolderName, "images"));
        await makeSureFolderExists(path.join(folderPath, splitFolderName, "labels"));

        // get info of image
        let image = await Sharp(path.join(folderPath, annotion.image)).metadata();

        // generate label files
        let line: [number, number, number, number, number][] = []
        for (let i of annotion.annotions) {
            line.push([
                classes.indexOf(i.label),
                i.coordinates.x / image.width!,
                i.coordinates.y / image.height!,
                i.coordinates.width / image.width!,
                i.coordinates.height / image.height!
            ])
        }

        await promises.rename(path.join(folderPath, annotion.image), path.join(folderPath, splitFolderName, "images", annotion.image));
        await promises.writeFile(path.join(folderPath, splitFolderName, "labels", annotion.image.replace(".jpg", ".txt")), line.map(l => l.join(" ")).join("\n"));
    }

    for (let annotion of annotions) {
        tasks.push(__task(annotion));
    }

    await concurrentlyExecutePromisesWithSize(tasks, CONCURRENT_PROMISE_SIZE);

}

(async () => {

    await generateCreateMLFormatOutput("./database", "./output");

    // await generateYoloFormatOutput("./output");

})();

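The idea behind the script is to randomly stretch these 40-odd images into various shapes, pick several of them, and scatter them across a canvas whose background color also varies, in order to simulate as many different scenes as possible.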
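The placement step is where the synthetic labels come from, so it helps to keep the geometry exact. A minimal sketch under these assumptions: boxes are axis-aligned (no rotation, matching `angle = 0` in the script), positions are center-based like Create ML's coordinates, and rejection sampling with a retry cap decides placement. `scatter`, `overlaps` and `Box` are hypothetical names, not part of the original script.

```typescript
// Scatter rectangles on a canvas without overlap using rejection sampling.
// Every returned box is center-based (Create ML style) and lies fully inside the canvas.
type Box = { x: number; y: number; width: number; height: number };

function overlaps(a: Box, b: Box): boolean {
  // Center-based boxes intersect when the center distance is below the sum of half-extents on both axes
  return Math.abs(a.x - b.x) * 2 < a.width + b.width && Math.abs(a.y - b.y) * 2 < a.height + b.height;
}

function scatter(
  canvasW: number,
  canvasH: number,
  sizes: [number, number][],
  rand: () => number = Math.random,
  maxTries = 50,
): (Box | null)[] {
  const placed: Box[] = [];
  return sizes.map(([w, h]) => {
    if (w > canvasW || h > canvasH) return null; // can never fit
    for (let attempt = 0; attempt < maxTries; attempt++) {
      // Pick a center such that the whole box stays inside the canvas
      const box: Box = { x: w / 2 + rand() * (canvasW - w), y: h / 2 + rand() * (canvasH - h), width: w, height: h };
      if (placed.every(p => !overlaps(p, box))) {
        placed.push(box);
        return box;
      }
    }
    return null; // gave up: the caller should skip this image and its annotation
  });
}

const boxes = scatter(1024, 800, [[200, 150], [300, 200], [100, 100]]);
console.log(boxes);
```

Returning `null` for images that could not be placed keeps the annotation list and the drawn images in sync: an image is only drawn, and only labeled, if it has a box.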

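Incidentally, the 500MB figure was not settled on all at once: I got there through repeated experiments, raising the data volume step by step to reach higher accuracy.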
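A note on `concurrentlyExecutePromisesWithSize` in the script above: `tasks.push(__task(i))` calls `__task` immediately, so every task is already running before the helper sees it, and the helper only waits for them in batches of 50; it does not cap how many run at once. If you want a real cap (for example to bound memory while decoding hundreds of images), a common pattern is to pass functions that create the promises. A minimal sketch; `runWithConcurrency` is a hypothetical name, not part of the original script.

```typescript
// Run task factories with at most `limit` tasks in flight at any moment.
// A Promise starts executing when it is created, so deferring creation is what enforces the cap.
async function runWithConcurrency<T>(tasks: Array<() => Promise<T>>, limit: number): Promise<T[]> {
  const results: T[] = new Array(tasks.length);
  let next = 0;

  // Each worker repeatedly claims the next unstarted task until none remain.
  // JavaScript is single-threaded, so `next++` cannot race between workers.
  async function worker(): Promise<void> {
    while (next < tasks.length) {
      const index = next++;
      results[index] = await tasks[index]();
    }
  }

  const workerCount = Math.max(1, Math.min(limit, tasks.length));
  await Promise.all(Array.from({ length: workerCount }, () => worker()));
  return results; // results keep the same order as `tasks`
}
```

In the script this would mean pushing `() => __task(i)` instead of `__task(i)` and awaiting `runWithConcurrency(tasks, CONCURRENT_PROMISE_SIZE)`.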

Model training

The next step is fairly simple: select your dataset in CreateML and start training:

As you can see, CreateML's Object Detection is based on YOLOv2. The most advanced YOLO version is probably YOLOv7, but YOLOv5 still has the most mature ecosystem.
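Because the synthetic data is generated in Create ML's format, the script also contains `generateYoloFormatOutput`, which appears intended for training other YOLO versions on the same data. The core of that conversion is small: YOLOv5-style label files hold one object per line, `<classIndex> <xCenter> <yCenter> <width> <height>`, with coordinates normalized to [0, 1] by the image size, while Create ML stores pixel values with the same center-based convention. A sketch of that one step; `toYoloLine` and `PixelBox` are hypothetical names.

```typescript
// Convert one Create ML-style box (pixels, center-based) into a YOLO label line.
interface PixelBox { label: string; x: number; y: number; width: number; height: number }

function toYoloLine(box: PixelBox, classes: string[], imageWidth: number, imageHeight: number): string {
  const classIndex = classes.indexOf(box.label);
  if (classIndex < 0) throw new Error(`unknown label: ${box.label}`);
  // Both formats are center-based, so conversion is only a division by the image size
  const values = [box.x / imageWidth, box.y / imageHeight, box.width / imageWidth, box.height / imageHeight];
  return [classIndex, ...values.map(v => v.toFixed(6))].join(" ");
}

console.log(toYoloLine({ label: "b", x: 512, y: 400, width: 256, height: 200 }, ["a", "b"], 1024, 800));
// prints "1 0.500000 0.500000 0.250000 0.250000"
```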

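On my M1 Pro machine, training takes 10+ hours, and it takes even longer on an Intel laptop. The whole process feels a bit like alchemy: 500-odd MB of files get distilled into an 80MB model file.

Model testing

Once training is finished, you get the model file shown in the "Trying out" section above. Drag any image into it to see how well the model performs:

Using the model in iOS

For Apple's official demo, see this example:

https://developer.apple.com/documentation/vision/recognizing_objects_in_live_capture

I also wrote a demo in SwiftUI: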

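//
//  ContentView.swift
//  DemoProject
//
//  EvaluationError, PhotoLibrary and VNObservationConfirmer are not part of this
//  snippet; they are defined elsewhere in the demo project.
//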

import SwiftUI
import Vision

class MyVNModel: ObservableObject {

    static let shared: MyVNModel = MyVNModel()

    @Published var parsedModel: VNCoreMLModel? = .none
    var images: [UIImage]? = .none
    var observationList: [[VNObservation]]? = .none

    func applyModelToCgImage(image: CGImage) async throws -> [VNObservation] {
        guard let parsedModel = parsedModel else {
            throw EvaluationError.resourceNotFound("cannot find parsedModel")
        }

        // Build the object-detection request. Without a completion handler,
        // the results are read from `request.results` after perform(_:) returns.
        let request = VNCoreMLRequest(model: parsedModel)
        #if targetEnvironment(simulator)
            request.usesCPUOnly = true
        #endif

        // perform(_:) runs synchronously and throws if the request fails, so errors
        // propagate to the caller instead of leaving an un-resumed continuation
        // (the previous version returned early on error and the await never finished).
        let requestHandler = VNImageRequestHandler(cgImage: image)
        try requestHandler.perform([request])

        guard let results = request.results else {
            throw EvaluationError.invalidExpression("cannot find observations in result")
        }
        return results
    }

    init() {
        Task(priority: .background) {
            let urlPath = Bundle.main.url(forResource: "genshin2", withExtension: "mlmodelc")
            guard let urlPath = urlPath else {
                print("cannot find file genshin2.mlmodelc")
                return
            }

            let config = MLModelConfiguration()
            let modelResp = await withCheckedContinuation { continuation in
                MLModel.load(contentsOf: urlPath, configuration: config) { result in
                    continuation.resume(returning: result)
                }
            }

            let model = try { () -> MLModel in
                switch modelResp {
                case let .success(m):
                    return m
                case let .failure(err):
                    throw err
                }
            }()

            let parsedModel = try VNCoreMLModel(for: model)
            DispatchQueue.main.async {
                self.parsedModel = parsedModel
            }
        }
    }

}

struct ContentView: View {

    enum SheetType: Identifiable {
        case photo
        case confirm
        var id: SheetType { self }
    }

    @State var showSheet: SheetType? = .none

    @ObservedObject var viewModel: MyVNModel = MyVNModel.shared

    var body: some View {
        VStack {
            Button {
                showSheet = .photo
            } label: {
                Text("Choose Photo")
            }
        }
        .sheet(item: $showSheet) { sheetType in
            switch sheetType {
            case .photo:
                PhotoLibrary(handlePickedImage: { images in

                    guard let images = images else {
                        print("no images is selected")
                        return
                    }

                    var observationList: [[VNObservation]] = []
                    Task {
                        for image in images {

                            guard let cgImage = image.cgImage else {
                                throw EvaluationError.cgImageRetrievalFailure
                            }

                            let result = try await viewModel.applyModelToCgImage(image: cgImage)
                            print("model applied: \(result)")

                            observationList.append(result)
                        }

                        DispatchQueue.main.async {
                            viewModel.images = images
                            viewModel.observationList = observationList
                            self.showSheet = .confirm
                        }
                    }

                }, selectionLimit: 1)
            case .confirm:
                if let images = viewModel.images, let observationList = viewModel.observationList {
                    VNObservationConfirmer(imageList: images, observations: observationList, onSubmit: { _,_  in

                    })
                } else {
                    Text("No Images \(viewModel.images?.count ?? 0) \(viewModel.observationList?.count ?? 0)")
                }

            }

        }
        .padding()
    }
}

struct ContentView_Previews: PreviewProvider {
    static var previews: some View {
        ContentView()
    }
}
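To draw the detections the demo collects (for example in `VNObservationConfirmer`, whose code is not shown), the boxes need converting: for an object-detection model, Vision returns `VNRecognizedObjectObservation` values whose `boundingBox` is normalized to [0, 1] with the origin at the lower-left corner, while UIKit and SwiftUI draw with the origin at the top-left. Vision's `VNImageRectForNormalizedRect` handles the scaling but not the y-axis flip. The arithmetic is language-independent, so here it is as a small TypeScript sketch; `visionRectToTopLeftPixels` is a hypothetical name.

```typescript
// Convert a Vision-style normalized rect (origin at lower-left) into a
// top-left-origin pixel rect suitable for drawing over the image.
interface Rect { x: number; y: number; width: number; height: number }

function visionRectToTopLeftPixels(normalized: Rect, imageWidth: number, imageHeight: number): Rect {
  return {
    x: normalized.x * imageWidth,
    // Flip the y axis: the box's top edge in Vision space is at y + height
    y: (1 - normalized.y - normalized.height) * imageHeight,
    width: normalized.width * imageWidth,
    height: normalized.height * imageHeight,
  };
}

console.log(visionRectToTopLeftPixels({ x: 0.25, y: 0.5, width: 0.5, height: 0.25 }, 1000, 800));
```

A box occupying the band from 50% to 75% of the height, measured from the bottom, ends up 200 to 400 pixels from the top of an 800-pixel-tall image.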

Result
