Android studio项目加载pytorch模型文件

文章讲述了如何将YOLO模型转换为torchscript格式并集成到Android应用中,通过添加PyTorch和torchvision的依赖库,实现模型的加载。接着,展示了如何进行目标检测的推理过程,包括非极大值抑制算法的实现,以过滤预测结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

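1. 首先把你的 YOLO 模型导出为 TorchScript 格式。注意:下文使用的 pytorch_android_lite 运行库只能加载轻量解释器(lite interpreter)格式的模型,导出时应使用 `_save_for_lite_interpreter` 保存(通常为 .ptl 文件),并在 Android 端用 `LiteModuleLoader.load` 加载。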

2. 然后把模型文件放到安卓项目的资源目录(如 app/src/main/assets)下,并一同放入类别标签文件(原文此处附有截图)。下文代码中用到的 PrePostProcessor.mClasses 需要在推理前根据该标签文件赋值。

3. 在 app 模块的 build.gradle 中添加依赖库:

    // PyTorch Mobile lite-interpreter runtime, plus its torchvision helpers (TensorImageUtils is used for Bitmap -> Tensor).
    // NOTE(review): runCAR also uses OpenCV (Mat, Utils, Imgproc); that library is not provided by these two lines.
    implementation 'org.pytorch:pytorch_android_lite:1.9.0'
    implementation 'org.pytorch:pytorch_android_torchvision_lite:1.9.0'

4. 编写模型推理及后处理代码(不做详细解释,直接给出代码)。

        4.1【推理】

public class PrePostProcessor {
    /** Per-channel mean handed to TensorImageUtils; all zeros, so no mean is subtracted. */
    public static float[] NO_MEAN_RGB = {0f, 0f, 0f};
    /** Per-channel std handed to TensorImageUtils; all ones, so no extra scaling is applied. */
    public static float[] NO_STD_RGB = {1f, 1f, 1f};

    /** Width, in pixels, the bitmap is resized to before inference. */
    public static int mInputWidth = 320;
    /** Height, in pixels, the bitmap is resized to before inference. */
    public static int mInputHeight = 320;

    // Number of candidate rows in the model output.
    // NOTE(review): for a 320x320 YOLOv5-style export this matches 3 anchors * (40^2 + 20^2 + 10^2) — confirm against the model.
    private static final int mOutputRow = 6300;
    // Floats per row: x, y, w, h, objectness, then one score per class (4 classes here).
    private static final int mOutputColumn = 9;
    // Used both as the objectness cut-off and as the NMS IoU threshold.
    private static final float mThreshold = 0.35f;
    // Maximum number of detections kept by NMS.
    private static final int mNmsLimit = 5;

    /** Class labels indexed by ResultCAR.classIndex; never assigned here, so callers must fill it before use. */
    public static String[] mClasses;

    /**
     * Intersection-over-union of two axis-aligned rectangles.
     *
     * @param a first rectangle
     * @param b second rectangle
     * @return overlap area divided by union area, or 0 when either rectangle has a non-positive area
     */
    static float IOU(Rect a, Rect b){
        final float areaA = (a.right - a.left) * (a.bottom - a.top);
        if (areaA <= 0.0) {
            return 0.0f;
        }
        final float areaB = (b.right - b.left) * (b.bottom - b.top);
        if (areaB <= 0.0) {
            return 0.0f;
        }

        // Overlap extent along each axis, clamped at zero when the rectangles do not meet on that axis.
        final float overlapHeight = Math.max((float) Math.min(a.bottom, b.bottom) - (float) Math.max(a.top, b.top), 0);
        final float overlapWidth = Math.max((float) Math.min(a.right, b.right) - (float) Math.max(a.left, b.left), 0);
        final float overlap = overlapHeight * overlapWidth;

        final float union = areaA + areaB - overlap;
        return overlap / union;
    }

    /**
     * Greedy non-maximum suppression.
     *
     * <p>Boxes are visited from highest to lowest score. Each surviving box is kept, and every
     * later box whose raw_rect overlaps it with IoU above {@code threshold} is suppressed.
     * Note that {@code boxes} is sorted in place.
     *
     * @param boxes     candidate detections (re-ordered by this call)
     * @param limit     stop once this many boxes have been kept
     * @param threshold IoU above which a lower-scored box is suppressed
     * @return the kept boxes, highest score first
     */
    static ArrayList<ResultCAR> nonMaxSuppression(ArrayList<ResultCAR> boxes, int limit, float threshold){
        // Descending by score: greedy NMS must keep the most confident box of each overlapping cluster.
        // (The previous ascending sort kept the weakest box of each cluster instead.)
        Collections.sort(boxes,
                new Comparator<ResultCAR>(){
                    @Override
                    public int compare(ResultCAR o1, ResultCAR o2){
                        return o2.score.compareTo(o1.score);
                    }
                });

        ArrayList<ResultCAR> selected = new ArrayList<>();
        boolean[] suppressed = new boolean[boxes.size()];
        for (int i = 0; i < boxes.size(); i++) {
            if (suppressed[i]) {
                continue;
            }
            ResultCAR boxA = boxes.get(i);
            selected.add(boxA);
            if (selected.size() >= limit) {
                break;
            }
            for (int j = i + 1; j < boxes.size(); j++) {
                if (!suppressed[j] && IOU(boxA.raw_rect, boxes.get(j).raw_rect) > threshold) {
                    suppressed[j] = true;
                }
            }
        }
        return selected;
    }

    /**
     * Decodes the flat model output into detections and reduces them with non-maximum suppression.
     *
     * <p>Each row holds {@code mOutputColumn} floats: box centre x, centre y, width, height,
     * an objectness score, then one score per class. Rows whose objectness does not exceed
     * {@code mThreshold} are dropped; the class is the arg-max of the class scores.
     *
     * @param outputs   flattened output tensor, row-major
     * @param imgScaleX multiplier from model-output x to source-image x
     * @param imgScaleY multiplier from model-output y to source-image y
     * @param ivScaleX  multiplier applied to the source-image x when building ResultCAR.rect
     * @param ivScaleY  multiplier applied to the source-image y when building ResultCAR.rect
     * @param startX    offset added to the scaled x when building ResultCAR.rect
     * @param startY    offset added to the scaled y when building ResultCAR.rect
     * @return at most {@code mNmsLimit} detections after NMS
     */
    public static ArrayList<ResultCAR> outputsToNMSPredictions(float[] outputs, float imgScaleX, float imgScaleY, float ivScaleX, float ivScaleY,float startX, float startY){
        ArrayList<ResultCAR> results = new ArrayList<>();
        // Never read past the end of the tensor if the model emits fewer rows than expected.
        final int rows = Math.min(mOutputRow, outputs.length / mOutputColumn);
        final int numClasses = mOutputColumn - 5;
        for (int i = 0; i < rows; i++) {
            final int base = i * mOutputColumn;
            final float objectness = outputs[base + 4];
            if (objectness > mThreshold) {
                float x = outputs[base];
                float y = outputs[base + 1];
                float w = outputs[base + 2];
                float h = outputs[base + 3];

                // Box corners scaled by imgScaleX/imgScaleY (source-bitmap pixels when called from runCAR).
                float left = imgScaleX * (x - w / 2);
                float top = imgScaleY * (y - h / 2);
                float right = imgScaleX * (x + w / 2);
                float bottom = imgScaleY * (y + h / 2);

                // Arg-max over the class scores; ties keep the lower index.
                int cls = 0;
                float max = outputs[base + 5];
                for (int j = 1; j < numClasses; j++) {
                    if (outputs[base + 5 + j] > max) {
                        max = outputs[base + 5 + j];
                        cls = j;
                    }
                }

                Rect rect = new Rect((int) (startX + ivScaleX * left), (int) (startY + top * ivScaleY),
                        (int) (startX + ivScaleX * right), (int) (startY + ivScaleY * bottom));
                // ResultCAR's constructor needs raw_rect too: NMS computes IoU on it and runCAR crops with it,
                // so it keeps the un-mapped box in the imgScale coordinate space.
                Rect rawRect = new Rect((int) left, (int) top, (int) right, (int) bottom);
                results.add(new ResultCAR(cls, objectness, rect, rawRect));
            }
        }
        return nonMaxSuppression(results, mNmsLimit, mThreshold);
    }

4.2【运行推理并获取结果】

    private static float  mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;
    /** Most recently saved labelled crop; stays null until runCAR saves one. */
    public  static Bitmap resimg  = null;

    /**
     * Runs the detector on a bitmap and returns one recognised class label.
     *
     * @param mBitmap       source image; a copy resized to the model input size is fed to the model
     * @param mModuleCarTag loaded module whose forward() result is a tuple with the detection tensor first
     * @param context       passed to the gallery saver when crops are saved
     * @param isSaveImage   when true, every kept detection is cropped from mBitmap, labelled and saved
     * @return the label (from PrePostProcessor.mClasses) with the largest class index among the kept
     *         detections, or null when nothing was detected
     */
    public static String runCAR(Bitmap mBitmap, Module mModuleCarTag , Context context, boolean isSaveImage) {
        final int bitmapWidth = mBitmap.getWidth();
        final int bitmapHeight = mBitmap.getHeight();
        // Multipliers from model-input coordinates back to source-bitmap pixels.
        mImgScaleX = (float) bitmapWidth / PrePostProcessor.mInputWidth;
        mImgScaleY = (float) bitmapHeight / PrePostProcessor.mInputHeight;
        // NOTE(review): both ivScales equal 1 / max(width, height), so ResultCAR.rect ends up in a
        // normalized [0, 1] square that outputsToNMSPredictions then truncates to int. Only raw_rect
        // is used below; confirm which coordinate space the UI expects for rect.
        mIvScaleX = (bitmapWidth > bitmapHeight ? (float) 1 / bitmapWidth : (float) 1 / bitmapHeight);
        mIvScaleY = (bitmapHeight > bitmapWidth ? (float) 1 / bitmapHeight : (float) 1 / bitmapWidth);
        mStartX = (1 - mIvScaleX * bitmapWidth) / 2;
        mStartY = (1 - mIvScaleY * bitmapHeight) / 2;

        // Resize the bitmap, then Bitmap -> Tensor.
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(mBitmap, PrePostProcessor.mInputWidth, PrePostProcessor.mInputHeight, true);
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
        IValue[] outputTuple = mModuleCarTag.forward(IValue.from(inputTensor)).toTuple();
        final float[] outputs = outputTuple[0].toTensor().getDataAsFloatArray();
        // Decode the raw output and apply non-maximum suppression.
        final ArrayList<ResultCAR> results = PrePostProcessor.outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY);

        Set<Integer> set = new HashSet<>();
        for (ResultCAR result : results) {
            set.add(result.classIndex);
            Log.e("置信度:", result.score + "");
            if (isSaveImage) {
                saveLabelledCrop(mBitmap, result, context);
            }
        }

        List<Integer> list = new ArrayList<>(set);
        Collections.sort(list);
        for (int classIndex : list) {
            Log.e("识别结果", PrePostProcessor.mClasses[classIndex]);
        }
        // The label of the largest detected class index wins; null when nothing was detected.
        return list.isEmpty() ? null : PrePostProcessor.mClasses[list.get(list.size() - 1)];
    }

    /** Crops one detection out of the source bitmap, draws its label on it, and saves it to the gallery. */
    private static void saveLabelledCrop(Bitmap source, ResultCAR result, Context context) {
        // Predicted boxes can extend past the image edges (or be empty); Bitmap.createBitmap
        // throws IllegalArgumentException for such rects, so clamp to the bitmap and skip empties.
        Rect crop = new Rect(result.raw_rect);
        if (!crop.intersect(0, 0, source.getWidth(), source.getHeight()) || crop.isEmpty()) {
            return;
        }
        Bitmap cropBitmap = Bitmap.createBitmap(source, crop.left, crop.top, crop.width(), crop.height());
        Mat mat = new Mat();
        try {
            Utils.bitmapToMat(cropBitmap, mat);
            Imgproc.putText(mat, PrePostProcessor.mClasses[result.classIndex], new org.opencv.core.Point(10, 10), 1, 1, new org.opencv.core.Scalar(0, 0, 255), 1);
            cropBitmap = ImageUtils.mat2Bitmap(mat);
        } finally {
            // Free the Mat's native buffer instead of waiting for GC.
            // NOTE(review): assumes ImageUtils.mat2Bitmap copies pixels into a new Bitmap — confirm.
            mat.release();
        }
        resimg = cropBitmap;
        SaveBitmap.saveImageToGallery(context, cropBitmap);
    }

4.3【补充 ResultCAR 类】

/**
 * A single detection produced by the post-processing step.
 */
public class ResultCAR {
    /** Index of the detected class; used to look up its label in PrePostProcessor.mClasses. */
    public int classIndex;
    /** Confidence score; NMS orders detections by it. */
    public Float score;
    /** Box after the ivScale/start mapping; per the accompanying text, intended for UI display. */
    public Rect rect;
    /** Box used for NMS IoU and, in runCAR, for cropping the source bitmap. */
    public Rect raw_rect;

    /**
     * @param cls      class index
     * @param output   confidence score
     * @param rect     mapped (display) box
     * @param raw_rect un-mapped box used for NMS and cropping
     */
    public ResultCAR(int cls, Float output, Rect rect,Rect raw_rect){
        classIndex = cls;
        score = output;
        this.rect = rect;
        this.raw_rect = raw_rect;
    }
}

5. 其余 UI 代码请自行实现。ResultCAR 中的 rect 存放目标框坐标(按 ivScale/start 参数映射后的坐标),raw_rect 则是 NMS 计算 IoU 以及 runCAR 在原图上裁剪目标时使用的框。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值