1. Overview
Deep Java Library (DJL) is an open-source, high-level, engine-agnostic Java framework for deep learning. DJL is designed to be easy to get started with and simple for Java developers to use, offering the same native Java development experience and functionality as any other regular Java library.
You don't have to be a machine learning or deep learning expert to get started. You can use your existing Java expertise as an on-ramp to learning and applying machine learning and deep learning, and you can build, train, and deploy models from your favorite IDE. DJL makes it easy to integrate these models into Java applications.
Because DJL is deep-learning-engine agnostic, you don't have to commit to an engine when creating your project, and you can switch engines at any time. To ensure the best performance, DJL also provides automatic CPU/GPU selection based on the hardware configuration.
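To see which engine and device DJL has actually picked at runtime, a minimal sketch such as the following can be used; the EngineInfo class name is just for illustration, while Engine and Device are DJL's own API:

import ai.djl.Device;
import ai.djl.engine.Engine;

public final class EngineInfo {
    public static void main(String[] args) {
        // DJL resolves the default engine from the engine dependencies found on the classpath
        Engine engine = Engine.getInstance();
        // defaultDevice() falls back to CPU when no usable GPU build is available
        Device device = engine.defaultDevice();
        System.out.println("Engine: " + engine.getEngineName() + ", device: " + device);
    }
}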
Environment and main software versions
- Java 8+
- Spring Boot 2.7.16-SNAPSHOT
- DJL 0.23.0
2. Building and Running the Project
Download the test image dataset
Download the ut-zap50k-images-square image set from https://vision.cs.utexas.edu/projects/finegrained/utzap50k/
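The ImageFolder dataset used for training below derives one label per top-level subdirectory, so after extracting the archive the dataset root should contain one folder per class (boots, sandals, shoes, slippers). An optional sanity check, assuming the default path used later in this demo (DatasetLayoutCheck is a hypothetical class name):

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

public final class DatasetLayoutCheck {
    public static void main(String[] args) throws IOException {
        // adjust to wherever the archive was extracted
        Path root = Paths.get("/home/djl-test/images-test");
        try (Stream<Path> entries = Files.list(root)) {
            // each top-level directory becomes a classification label
            entries.filter(Files::isDirectory)
                   .forEach(dir -> System.out.println("label: " + dir.getFileName()));
        }
    }
}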
Create a Spring Boot project
Key dependencies in pom.xml
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- DJL -->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
</dependency>
<!-- pytorch-engine-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<scope>runtime</scope>
</dependency>
<!-- knife4j -->
<dependency>
<groupId>com.github.xiaoymin</groupId>
<artifactId>knife4j-micro-spring-boot-starter</artifactId>
<version>3.0.2</version>
</dependency>
<dependency>
<groupId>com.github.xiaoymin</groupId>
<artifactId>knife4j-spring-boot-starter</artifactId>
<version>3.0.2</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>windows</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<dependencies>
<!-- Windows CPU -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>win-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>centos7</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<!-- For Pre-CXX11 build (CentOS7)-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu-precxx11</artifactId>
<classifier>linux-x86_64</classifier>
<version>2.0.1</version>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>linux</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<!-- Linux CPU -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<classifier>linux-x86_64</classifier>
<scope>runtime</scope>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>aarch64</id>
<activation>
<activeByDefault>false</activeByDefault>
</activation>
<dependencies>
<!-- For aarch64 build-->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu-precxx11</artifactId>
<classifier>linux-aarch64</classifier>
<scope>runtime</scope>
<version>2.0.1</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>2.0.1-0.23.0</version>
<scope>runtime</scope>
</dependency>
</dependencies>
</profile>
</profiles>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>bom</artifactId>
<version>0.23.0</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>
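With the BOM, profiles, and engine artifacts in place, it can be worth confirming that the native PyTorch libraries selected by the active Maven profile really load at runtime. A minimal, optional check (PyTorchEngineCheck is a hypothetical class name; Engine.getEngine is DJL's API):

import ai.djl.engine.Engine;

public final class PyTorchEngineCheck {
    public static void main(String[] args) {
        // fails fast if the pytorch-engine/native artifacts cannot be resolved or loaded
        Engine pytorch = Engine.getEngine("PyTorch");
        System.out.println("PyTorch engine version: " + pytorch.getVersion());
    }
}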
application.yml
server:
  port: 8888
spring:
  application:
    name: djl-image-classification-demo
  servlet:
    multipart:
      max-file-size: 100MB
      max-request-size: 100MB
  mvc:
    pathmatch:
      matching-strategy: ant_path_matcher
knife4j:
  enable: true
djl:
  num-of-output: 4
Models.java
package com.example.djl.demo;
import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
/**
* A helper class that loads and saves the model.
* @author lakudouzi
*/
public final class Models {
// the height and width for pre-processing of the image
public static final int IMAGE_HEIGHT = 100;
public static final int IMAGE_WIDTH = 100;
// the name of the model
public static final String MODEL_NAME = "shoeclassifier";
private Models() {
}
public static Model getModel(int numOfOutput) {
// create new instance of an empty model
Model model = Model.newInstance(MODEL_NAME);
// Block is a composable unit that forms a neural network; combine them like Lego blocks
// to form a complex network
Block resNet50 =
ResNetV1.builder() // construct the network
.setImageShape(new Shape(3, IMAGE_HEIGHT, IMAGE_WIDTH))
.setNumLayers(50)
.setOutSize(numOfOutput)
.build();
// set the neural network to the model
model.setBlock(resNet50);
return model;
}
public static void saveSynset(Path modelDir, List<String> synset) throws IOException {
Path synsetFile = modelDir.resolve("synset.txt");
try (Writer writer = Files.newBufferedWriter(synsetFile)) {
writer.write(String.join("\n", synset));
}
}
}
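As a quick smoke test, the helper can be exercised on its own before wiring it into the service. A minimal sketch, assuming four output classes as configured by djl.num-of-output (ModelsSmokeTest is a hypothetical class name):

import ai.djl.Model;
import com.example.djl.demo.Models;

public final class ModelsSmokeTest {
    public static void main(String[] args) {
        // builds the empty ResNet50 block; no weights are loaded yet
        try (Model model = Models.getModel(4)) {
            System.out.println("Model name: " + model.getName());
            System.out.println("Block type: " + model.getBlock().getClass().getSimpleName());
        }
    }
}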
ImageClassificationServiceImpl.java
package com.example.djl.demo.service.impl;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageFolder;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import com.example.djl.demo.Models;
import com.example.djl.demo.service.ImageClassificationService;
import lombok.Cleanup;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
* @author lakudouzi
*/
@Slf4j
@Service
public class ImageClassificationServiceImpl implements ImageClassificationService {
// represents number of training samples processed before the model is updated
private static final int BATCH_SIZE = 32;
// the number of passes over the complete dataset
private static final int EPOCHS = 2;
//the number of classification labels: boots, sandals, shoes, slippers
@Value("${djl.num-of-output:4}")
public int numOfOutput;
@Override
public String predict(MultipartFile image, String modelPath) throws IOException, MalformedModelException, TranslateException {
@Cleanup
InputStream is = image.getInputStream();
Path modelDir = Paths.get(modelPath);
BufferedImage bi = ImageIO.read(is);
Image img = ImageFactory.getInstance().fromImage(bi);
// empty model instance
try (Model model = Models.getModel(numOfOutput)) {
// load the model
model.load(modelDir, Models.MODEL_NAME);
// define a translator for pre- and post-processing
// it resizes the input image to the shape the network was trained on and applies softmax to the output
Translator<Image, Classifications> translator =
ImageClassificationTranslator.builder()
.addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
.addTransform(new ToTensor())
.optApplySoftmax(true)
.build();
// run the inference using a Predictor
try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
// holds the probability score per label
Classifications predictResult = predictor.predict(img);
log.info("结果={}",predictResult.toJson());
return predictResult.toJson();
}
}
}
@Override
public String training(String datasetRoot, String modelPath) throws TranslateException, IOException {
log.info("Training on the image dataset started. Dataset path: {}", datasetRoot);
// the location to save the model
Path modelDir = Paths.get(modelPath);
// create ImageFolder dataset from directory
ImageFolder dataset = initDataset(datasetRoot);
// split the dataset into a training set and a validation set (80/20)
RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);
// set loss function, which seeks to minimize errors
// loss function evaluates model's predictions against the correct answer (during training)
// higher numbers are bad - means model performed poorly; indicates more errors; want to
// minimize errors (loss)
Loss loss = Loss.softmaxCrossEntropyLoss();
// setting training parameters (ie hyperparameters)
TrainingConfig config = setupTrainingConfig(loss);
try (Model model = Models.getModel(numOfOutput); // empty model instance to hold patterns
Trainer trainer = model.newTrainer(config)) {
// metrics collect and report key performance indicators, like accuracy
trainer.setMetrics(new Metrics());
Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);
// initialize trainer with proper input shape
trainer.initialize(inputShape);
// find the patterns in data
EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);
// set model properties
TrainingResult result = trainer.getTrainingResult();
model.setProperty("Epoch", String.valueOf(EPOCHS));
model.setProperty(
"Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
// save the model after done training for inference later
// model saved as shoeclassifier-0000.params
model.save(modelDir, Models.MODEL_NAME);
// save labels into model directory
Models.saveSynset(modelDir, dataset.getSynset());
log.info("图片数据集训练结束......");
return String.join("\n", dataset.getSynset());
}
}
private ImageFolder initDataset(String datasetRoot)
throws IOException, TranslateException {
ImageFolder dataset =
ImageFolder.builder()
// retrieve the data
.setRepositoryPath(Paths.get(datasetRoot))
.optMaxDepth(10)
.addTransform(new Resize(Models.IMAGE_WIDTH, Models.IMAGE_HEIGHT))
.addTransform(new ToTensor())
// random sampling; don't process the data in order
.setSampling(BATCH_SIZE, true)
.build();
dataset.prepare();
return dataset;
}
private TrainingConfig setupTrainingConfig(Loss loss) {
return new DefaultTrainingConfig(loss)
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
}
}
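The service returns the raw Classifications JSON. If only the top label is needed, a small hypothetical helper along these lines could be added (not part of the original demo):

import ai.djl.modality.Classifications;

public final class ClassificationUtils {

    private ClassificationUtils() {
    }

    // returns e.g. "Shoes (0.9123)" for the highest-probability class
    public static String topLabel(Classifications classifications) {
        Classifications.Classification best = classifications.best();
        return String.format("%s (%.4f)", best.getClassName(), best.getProbability());
    }
}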
ImageClassificationController.java
package com.example.djl.demo.rest;
import ai.djl.MalformedModelException;
import ai.djl.translate.TranslateException;
import com.example.djl.demo.service.ImageClassificationService;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
/**
* @author lakudouzi
*/
@RestController
@RequiredArgsConstructor
public class ImageClassificationController {
private final ImageClassificationService imageClassificationService;
@PostMapping(path = "/analyze")
@ApiOperation("图片对象识别-上传图片识别")
public String predict(@RequestPart("image") MultipartFile image,
@RequestParam(defaultValue = "/home/djl-test/models") String modelPath)
throws TranslateException,
MalformedModelException,
IOException {
return imageClassificationService.predict(image, modelPath);
}
@PostMapping(path = "/training")
@ApiOperation("图片对象识别-训练图片数据集")
public String training(@RequestParam(defaultValue = "/home/djl-test/images-test")
String datasetRoot,
@RequestParam(defaultValue = "/home/djl-test/models") String modelPath) throws TranslateException, IOException {
return imageClassificationService.training(datasetRoot, modelPath);
}
@GetMapping("/download")
@ApiOperation(value = "Image classification - download a test image", produces = "application/octet-stream")
public ResponseEntity<Resource> downloadFile(@RequestParam(defaultValue = "/home/djl-test/images-test") String directoryPath) {
List<String> imgPathList = new ArrayList<>();
try (Stream<Path> paths = Files.walk(Paths.get(directoryPath))) {
// Filter only regular files (excluding directories)
paths.filter(Files::isRegularFile)
.forEach(c -> imgPathList.add(c.toString()));
} catch (IOException e) {
return ResponseEntity.status(500).build();
}
// guard against an empty directory before picking a random file
if (imgPathList.isEmpty()) {
return ResponseEntity.notFound().build();
}
Random random = new Random();
String filePath = imgPathList.get(random.nextInt(imgPathList.size()));
Path file = Paths.get(filePath);
Resource resource = new FileSystemResource(file.toFile());
if (!resource.exists()) {
return ResponseEntity.notFound().build();
}
HttpHeaders headers = new HttpHeaders();
headers.add(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=" + file.getFileName().toString());
headers.add(HttpHeaders.CONTENT_TYPE, MediaType.IMAGE_JPEG_VALUE);
try {
return ResponseEntity.ok()
.headers(headers)
.contentLength(resource.contentLength())
.body(resource);
} catch (IOException e) {
return ResponseEntity.status(500).build();
}
}
}
Knife4j is integrated to make testing easier. Start the project and open http://localhost:8888/doc.html to try it out.
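Besides the Knife4j UI, the endpoints can also be called programmatically. A minimal sketch using Spring's RestTemplate against the controller defaults above (TrainingClient is a hypothetical class name):

import org.springframework.web.client.RestTemplate;

public final class TrainingClient {
    public static void main(String[] args) {
        RestTemplate restTemplate = new RestTemplate();
        // kicks off training and returns the label list that was saved to synset.txt
        String labels = restTemplate.postForObject(
                "http://localhost:8888/training"
                        + "?datasetRoot=/home/djl-test/images-test"
                        + "&modelPath=/home/djl-test/models",
                null,
                String.class);
        System.out.println(labels);
    }
}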
Test results
Train a model from the downloaded dataset.
Local testing uses the Windows CPU profile for training.
Download a test image and run classification on it.
The shoe category predicted for the test image:
As the output shows, the image is classified as Shoes, which has the highest probability.
The project also runs fine on a CentOS 7 virtual machine (just switch the Maven profile to centos7 when packaging).
References:
- https://docs.djl.ai/engines/pytorch/pytorch-engine/index.html
- https://docs.djl.ai/docs/demos/footwear_classification/index.html#train-the-footwear-classification-model
That wraps it up. I'm still not very familiar with DJL, and I don't know whether there is a better way to speed up training or what the optimal approach would be. Feedback and discussion are welcome.