【實操乾貨】創建一個用在影像內部進行對象檢測的Android應用程式

  • 2019 年 11 月 20 日
  • 筆記

在移動設備上運行機器學習程式碼是下一件大事。 PyTorch在最新版本的PyTorch 1.3中添加了PyTorch Mobile,用於在Android和iOS設備上部署機器學習模型。

在這裡,我們將研究創建一個用於在影像內部進行對象檢測的Android應用程式;如下圖所示。

應用程式的演示運行

步驟1:準備模型

在本教程中,我們將使用經過預訓練好的ResNet18模型。ResNet18是具有1000個分類類別的最先進的電腦視覺模型。

1.安裝Torchvision庫

pip install torchvision

2.下載並跟蹤ResNet18模型

我們追蹤這個模型是因為我們需要一個可執行的ScriptModule來進行即時編譯。

import torch  import torchvision  resnet18 = torchvision.models.resnet18(pretrained=True)  resnet18.eval()  example_inputs = torch.rand(1, 3, 224, 224)  resnet18_traced = torch.jit.trace(resnet18, example_inputs = example_inputs)  resnet18_traced.save("resnet18_traced.pt")

注意:

  1. 將resnet18_traced.pt存儲在一個已知的位置,在本教程的後續步驟中我們將需要此位置。
  2. 在torch.rand中,我們採用了224 * 224的尺寸,因為ResNet18接受224 * 224的尺寸。

步驟2:製作Android應用程式

1.如果尚未安裝,請下載並安裝Android Studio,如果是,請單擊「是」以下載和安裝SDK。鏈接:https://developer.android.com/studio

2.打開Android Studio,然後單擊:啟動一個新的Android Studio項目

3.選擇清空活動

4.輸入應用程式名稱:ObjectDetectorDemo,然後按Finish

5.安裝NDK運行Android內部運行原生程式碼:

  • 轉到Tools> SDK Manager
  • 單擊SDK工具
  • 選中NDK(並排)旁邊的框

6.添加依賴項

Insidebuild.gradle(Module:app)。

在依賴項中添加以下內容

dependencies {      implementation fileTree(dir: 'libs', include: ['*.jar'])      implementation 'androidx.appcompat:appcompat:1.0.2'      implementation 'androidx.constraintlayout:constraintlayout:1.1.3'  implementation 'org.pytorch:pytorch_android:1.3.0'      implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'  }

7.添加基本布局以載入影像並顯示結果

轉到app> res> layout> activity_main.xml,然後添加以下程式碼

<ImageView      android:id="@+id/image"      app:layout_constraintTop_toTopOf="parent"      android:layout_width="match_parent"      android:layout_height="400dp"      android:layout_marginBottom="20dp"      android:scaleType="fitCenter" />    <TextView      android:id="@+id/result_text"      android:layout_width="match_parent"      android:layout_height="wrap_content"      android:layout_gravity="top"      android:text=""      android:textSize="20dp"      android:textStyle="bold"      android:textAllCaps="true"      android:textAlignment="center"      app:layout_constraintTop_toTopOf="@id/button"      app:layout_constraintBottom_toBottomOf="@+id/image" />    <Button      android:id="@+id/button"      android:layout_width="match_parent"      android:layout_height="wrap_content"      android:text="Load Image"      app:layout_constraintBottom_toBottomOf="@+id/result_text"      app:layout_constraintTop_toTopOf="@+id/detect" />    <Button      android:id="@+id/detect"      android:layout_width="match_parent"      android:layout_height="wrap_content"      android:text="Detect"      android:layout_marginBottom="50dp"      app:layout_constraintBottom_toBottomOf="parent" />

您的布局應如下圖所示

8.我們需要設置許可權以讀取設備上的影像存儲

轉到app> manifests> AndroidManifest.xml,然後在manifest標籤內添加以下程式碼

<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>

獲取應用程式載入許可權(僅在您授予許可權之前詢問)

—轉到Main Activity java。在onCreate()方法中添加以下程式碼。

if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {      requestPermissions(new String[]  {android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1);  }

9.複製模型

現在是時候複製使用python腳本創建的模型了。

從文件資源管理器/查找器中打開您的應用程式。

轉到app > src > main

創建一個名為assets的文件夾將模型複製到此文件夾中。打開後,您將在Android Studio中看到如下圖所示。(如果沒有,請右鍵單擊應用程式文件夾,然後單擊「同步應用程式」)

10.我們需要列出模型的輸出類

轉到app > java

在第一個文件夾中,將新的Java類名稱命名為ModelClasses。

將類的列表定義為(整個列表為1000個類,因此可以在此處複製所有內容(檢查Json或Git)以獲取完整列表,然後在下面的列表內複製):

public static String[] MODEL_CLASSES = new String[]{          "tench, Tinca tinca",          "goldfish, Carassius auratus"          .          .          .  }

11.Main Activity Java,這裡將定義按鈕動作,讀取影像並調用PyTorch模型。請參閱程式碼內的注釋以獲取解釋。

package com.tckmpsi.objectdetectordemo;    import androidx.appcompat.app.AppCompatActivity;    import android.content.Context;  import android.content.Intent;  import android.database.Cursor;  import android.graphics.Bitmap;  import android.graphics.BitmapFactory;  import android.graphics.drawable.BitmapDrawable;  import android.net.Uri;  import android.os.Build;  import android.os.Bundle;  import android.provider.MediaStore;  import android.view.View;  import android.widget.Button;  import android.widget.ImageView;  import android.widget.TextView;    import org.pytorch.IValue;  import org.pytorch.Module;  import org.pytorch.Tensor;  import org.pytorch.torchvision.TensorImageUtils;    import java.io.File;  import java.io.FileOutputStream;  import java.io.IOException;  import java.io.InputStream;  import java.io.OutputStream;    public class MainActivity extends AppCompatActivity {      private static int RESULT_LOAD_IMAGE = 1;        @Override      protected void onCreate(Bundle savedInstanceState) {          super.onCreate(savedInstanceState);          setContentView(R.layout.activity_main);            Button buttonLoadImage = (Button) findViewById(R.id.button);          Button detectButton = (Button) findViewById(R.id.detect);              if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {              requestPermissions(new String[]{android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1);          }          buttonLoadImage.setOnClickListener(new View.OnClickListener() {                @Override              public void onClick(View arg0) {                  TextView textView = findViewById(R.id.result_text);                  textView.setText("");                  Intent i = new Intent(                          Intent.ACTION_PICK,                          MediaStore.Images.Media.EXTERNAL_CONTENT_URI);                    startActivityForResult(i, RESULT_LOAD_IMAGE);                  }          });            detectButton.setOnClickListener(new View.OnClickListener() {                @Override              public void onClick(View arg0) {                    Bitmap bitmap = null;                  Module module = null;                    //Getting the image from the image view                  ImageView imageView = (ImageView) findViewById(R.id.image);                    try {                      //Read the image as Bitmap                      bitmap = ((BitmapDrawable)imageView.getDrawable()).getBitmap();                        //Here we reshape the image into 400*400                      bitmap = Bitmap.createScaledBitmap(bitmap, 400, 400, true);                        //Loading the model file.                      module = Module.load(fetchModelFile(MainActivity.this, "resnet18_traced.pt"));                  } catch (IOException e) {                      finish();                  }                    //Input Tensor                  final Tensor input = TensorImageUtils.bitmapToFloat32Tensor(                          bitmap,                          TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,                          TensorImageUtils.TORCHVISION_NORM_STD_RGB                  );                    //Calling the forward of the model to run our input                  final Tensor output = module.forward(IValue.from(input)).toTensor();                      final float[] score_arr = output.getDataAsFloatArray();                    // Fetch the index of the value with maximum score                  float max_score = -Float.MAX_VALUE;                  int ms_ix = -1;                  for (int i = 0; i < score_arr.length; i++) {                      if (score_arr[i] > max_score) {                          max_score = score_arr[i];                          ms_ix = i;                      }                  }                    //Fetching the name from the list based on the index                  String detected_class = ModelClasses.MODEL_CLASSES[ms_ix];                    //Writing the detected class in to the text view of the layout                  TextView textView = findViewById(R.id.result_text);                  textView.setText(detected_class);                  }          });        }      @Override      protected void onActivityResult(int requestCode, int resultCode, Intent data) {          //This functions return the selected image from gallery          super.onActivityResult(requestCode, resultCode, data);            if (requestCode == RESULT_LOAD_IMAGE && resultCode == RESULT_OK && null != data) {              Uri selectedImage = data.getData();              String[] filePathColumn = { MediaStore.Images.Media.DATA };                Cursor cursor = getContentResolver().query(selectedImage,                      filePathColumn, null, null, null);              cursor.moveToFirst();                int columnIndex = cursor.getColumnIndex(filePathColumn[0]);              String picturePath = cursor.getString(columnIndex);              cursor.close();                ImageView imageView = (ImageView) findViewById(R.id.image);              imageView.setImageBitmap(BitmapFactory.decodeFile(picturePath));                //Setting the URI so we can read the Bitmap from the image              imageView.setImageURI(null);              imageView.setImageURI(selectedImage);              }          }        public static String fetchModelFile(Context context, String modelName) throws IOException {          File file = new File(context.getFilesDir(), modelName);          if (file.exists() && file.length() > 0) {              return file.getAbsolutePath();          }            try (InputStream is = context.getAssets().open(modelName)) {              try (OutputStream os = new FileOutputStream(file)) {                  byte[] buffer = new byte[4 * 1024];                  int read;                  while ((read = is.read(buffer)) != -1) {                      os.write(buffer, 0, read);                  }                  os.flush();              }              return file.getAbsolutePath();          }      }    }

12.現在是時候測試應用程式了。兩種方法有兩種:

  • 在模擬器上運行(https://developer.android.com/studio/run/emulator)。
  • 使用Android設備。(為此,您需要啟用USB調試(http://developer.android.com/studio/run/emulator))。
  • 運行應用程式後,它的外觀應類似於頁面頂部的GIF。

鏈接到Git存儲庫:https://github.com/tusharck/Object-Detector-Android-App-Using-PyTorch-Mobile-Neural-Network

好看的人才能點