【实操干货】创建一个用在图像内部进行对象检测的Android应用程序
- 2019 年 11 月 20 日
- 筆記

在移动设备上运行机器学习代码是下一件大事。 PyTorch在最新版本的PyTorch 1.3中添加了PyTorch Mobile,用于在Android和iOS设备上部署机器学习模型。
在这里,我们将研究创建一个用于在图像内部进行对象检测的Android应用程序;如下图所示。

应用程序的演示运行

步骤1:准备模型
在本教程中,我们将使用经过预训练好的ResNet18模型。ResNet18是具有1000个分类类别的最先进的计算机视觉模型。
1.安装Torchvision库
pip install torchvision
2.下载并跟踪ResNet18模型
我们追踪这个模型是因为我们需要一个可执行的ScriptModule来进行即时编译。
import torch import torchvision resnet18 = torchvision.models.resnet18(pretrained=True) resnet18.eval() example_inputs = torch.rand(1, 3, 224, 224) resnet18_traced = torch.jit.trace(resnet18, example_inputs = example_inputs) resnet18_traced.save("resnet18_traced.pt")
注意:
- 将resnet18_traced.pt存储在一个已知的位置,在本教程的后续步骤中我们将需要此位置。
- 在torch.rand中,我们采用了224 * 224的尺寸,因为ResNet18接受224 * 224的尺寸。

步骤2:制作Android应用程序
1.如果尚未安装,请下载并安装Android Studio,如果是,请单击“是”以下载和安装SDK。链接:https://developer.android.com/studio
2.打开Android Studio,然后单击:启动一个新的Android Studio项目
3.选择清空活动

4.输入应用程序名称:ObjectDetectorDemo,然后按Finish

5.安装NDK运行Android内部运行原生代码:
- 转到Tools> SDK Manager
- 单击SDK工具
- 选中NDK(并排)旁边的框

6.添加依赖项
Insidebuild.gradle(Module:app)。
在依赖项中添加以下内容
dependencies { implementation fileTree(dir: 'libs', include: ['*.jar']) implementation 'androidx.appcompat:appcompat:1.0.2' implementation 'androidx.constraintlayout:constraintlayout:1.1.3' implementation 'org.pytorch:pytorch_android:1.3.0' implementation 'org.pytorch:pytorch_android_torchvision:1.3.0' }
7.添加基本布局以加载图像并显示结果
转到app> res> layout> activity_main.xml,然后添加以下代码
<ImageView android:id="@+id/image" app:layout_constraintTop_toTopOf="parent" android:layout_width="match_parent" android:layout_height="400dp" android:layout_marginBottom="20dp" android:scaleType="fitCenter" /> <TextView android:id="@+id/result_text" android:layout_width="match_parent" android:layout_height="wrap_content" android:layout_gravity="top" android:text="" android:textSize="20dp" android:textStyle="bold" android:textAllCaps="true" android:textAlignment="center" app:layout_constraintTop_toTopOf="@id/button" app:layout_constraintBottom_toBottomOf="@+id/image" /> <Button android:id="@+id/button" android:layout_width="match_parent" android:layout_height="wrap_content" android:text="Load Image" app:layout_constraintBottom_toBottomOf="@+id/result_text" app:layout_constraintTop_toTopOf="@+id/detect" /> <Button android:id="@+id/detect" android:layout_width="match_parent" android:layout_height="wrap_content" android:text="Detect" android:layout_marginBottom="50dp" app:layout_constraintBottom_toBottomOf="parent" />
您的布局应如下图所示

8.我们需要设置权限以读取设备上的图像存储
转到app> manifests> AndroidManifest.xml,然后在manifest标签内添加以下代码
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
获取应用程序加载权限(仅在您授予权限之前询问)
—转到Main Activity java。在onCreate()方法中添加以下代码。
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { requestPermissions(new String[] {android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1); }
9.复制模型
现在是时候复制使用python脚本创建的模型了。
从文件资源管理器/查找器中打开您的应用程序。
转到app > src > main。
创建一个名为assets的文件夹将模型复制到此文件夹中。打开后,您将在Android Studio中看到如下图所示。(如果没有,请右键单击应用程序文件夹,然后单击“同步应用程序”)

10.我们需要列出模型的输出类
转到app > java
在第一个文件夹中,将新的Java类名称命名为ModelClasses。
将类的列表定义为(整个列表为1000个类,因此可以在此处复制所有内容(检查Json或Git)以获取完整列表,然后在下面的列表内复制):
public static String[] MODEL_CLASSES = new String[]{ "tench, Tinca tinca", "goldfish, Carassius auratus" . . . }
11.Main Activity Java,这里将定义按钮动作,读取图像并调用PyTorch模型。请参阅代码内的注释以获取解释。
package com.tckmpsi.objectdetectordemo; import androidx.appcompat.app.AppCompatActivity; import android.content.Context; import android.content.Intent; import android.database.Cursor; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.graphics.drawable.BitmapDrawable; import android.net.Uri; import android.os.Build; import android.os.Bundle; import android.provider.MediaStore; import android.view.View; import android.widget.Button; import android.widget.ImageView; import android.widget.TextView; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; public class MainActivity extends AppCompatActivity { private static int RESULT_LOAD_IMAGE = 1; @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Button buttonLoadImage = (Button) findViewById(R.id.button); Button detectButton = (Button) findViewById(R.id.detect); if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { requestPermissions(new String[]{android.Manifest.permission.READ_EXTERNAL_STORAGE}, 1); } buttonLoadImage.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View arg0) { TextView textView = findViewById(R.id.result_text); textView.setText(""); Intent i = new Intent( Intent.ACTION_PICK, MediaStore.Images.Media.EXTERNAL_CONTENT_URI); startActivityForResult(i, RESULT_LOAD_IMAGE); } }); detectButton.setOnClickListener(new View.OnClickListener() { @Override public void onClick(View arg0) { Bitmap bitmap = null; Module module = null; //Getting the image from the image view ImageView imageView = (ImageView) findViewById(R.id.image); try { //Read the image as Bitmap bitmap = ((BitmapDrawable)imageView.getDrawable()).getBitmap(); //Here we reshape the image into 400*400 bitmap = Bitmap.createScaledBitmap(bitmap, 400, 400, true); //Loading the model file. module = Module.load(fetchModelFile(MainActivity.this, "resnet18_traced.pt")); } catch (IOException e) { finish(); } //Input Tensor final Tensor input = TensorImageUtils.bitmapToFloat32Tensor( bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB ); //Calling the forward of the model to run our input final Tensor output = module.forward(IValue.from(input)).toTensor(); final float[] score_arr = output.getDataAsFloatArray(); // Fetch the index of the value with maximum score float max_score = -Float.MAX_VALUE; int ms_ix = -1; for (int i = 0; i < score_arr.length; i++) { if (score_arr[i] > max_score) { max_score = score_arr[i]; ms_ix = i; } } //Fetching the name from the list based on the index String detected_class = ModelClasses.MODEL_CLASSES[ms_ix]; //Writing the detected class in to the text view of the layout TextView textView = findViewById(R.id.result_text); textView.setText(detected_class); } }); } @Override protected void onActivityResult(int requestCode, int resultCode, Intent data) { //This functions return the selected image from gallery super.onActivityResult(requestCode, resultCode, data); if (requestCode == RESULT_LOAD_IMAGE && resultCode == RESULT_OK && null != data) { Uri selectedImage = data.getData(); String[] filePathColumn = { MediaStore.Images.Media.DATA }; Cursor cursor = getContentResolver().query(selectedImage, filePathColumn, null, null, null); cursor.moveToFirst(); int columnIndex = cursor.getColumnIndex(filePathColumn[0]); String picturePath = cursor.getString(columnIndex); cursor.close(); ImageView imageView = (ImageView) findViewById(R.id.image); imageView.setImageBitmap(BitmapFactory.decodeFile(picturePath)); //Setting the URI so we can read the Bitmap from the image imageView.setImageURI(null); imageView.setImageURI(selectedImage); } } public static String fetchModelFile(Context context, String modelName) throws IOException { File file = new File(context.getFilesDir(), modelName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(modelName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } } }
12.现在是时候测试应用程序了。两种方法有两种:
- 在模拟器上运行(https://developer.android.com/studio/run/emulator)。
- 使用Android设备。(为此,您需要启用USB调试(http://developer.android.com/studio/run/emulator))。
- 运行应用程序后,它的外观应类似于页面顶部的GIF。
链接到Git存储库:https://github.com/tusharck/Object-Detector-Android-App-Using-PyTorch-Mobile-Neural-Network
好看的人才能点